working ok
@@ -0,0 +1,49 @@
# Agent guide — voice-chat

Orientation for AI agents making changes in this repo. Keep edits scoped; this pipeline has several sharp edges that a naive refactor will hit.

## What this project is

A local real-time voice (and optionally video-avatar) chat server. FastAPI + WebSocket on the backend, a small vanilla-JS UI on the frontend. All ML runs locally on a single NVIDIA GPU; there are no cloud API calls at runtime.

## Two-tier architecture

1. **Audio pipeline** (always on): mic PCM → [vad.py](server/vad.py) → [asr.py](server/asr.py) → [llm.py](server/llm.py) → [tts.py](server/tts.py) → PCM back to the browser. Orchestrated by [pipeline.py](server/pipeline.py) via `ConversationSession`. Supports barge-in (user speaks → cancel in-flight reply).

2. **Video pipeline** (optional, gated by `config.video.enabled`): per assistant turn, [video.py](server/video.py)'s `VideoEngine` calls [video_models/wan22.py](server/video_models/wan22.py) (LightX2V Wan2.2-I2V) to produce base frames, then [video_models/musetalk.py](server/video_models/musetalk.py) to lip-sync them to the TTS audio, then [video_models/muxer.py](server/video_models/muxer.py) to produce an MP4. The MP4 is sent over the same `/ws/chat` WebSocket as a `speaking_clip` message.

The audio path must keep working when `video.enabled` is false. Don't make video models load-bearing for the audio pipeline.

## Key files

- [server/main.py](server/main.py) — FastAPI app, WebSocket, video/avatar HTTP endpoints
- [server/models.py](server/models.py) — lifetime of all models; `ModelManager.video_engine` is `None` when video is disabled
- [server/pipeline.py](server/pipeline.py) — per-session audio pipeline + video branch
- [server/config.py](server/config.py) — parses [config.yml](config.yml)
- [server/video.py](server/video.py) — `VideoConfig`, `LoRASpec`, `VideoEngine` (library vs reflective modes)
- [server/video_models/wan22.py](server/video_models/wan22.py) — LightX2V wrapper; fp8 + GGUF loading; Blackwell patches
- [configs/lightx2v/](configs/lightx2v/) — LightX2V inference config templates; must match `wan22_dit_quant_scheme`
- [tests/unit/](tests/unit/) — GPU-free tests, runnable on Windows host
- [tests/component/](tests/component/) — end-to-end tests, must run inside the Docker container

## Conventions

- Config: single source of truth is [config.yml](config.yml) → `server/config.py` dataclasses. Don't read env vars for runtime behaviour; if you need a new knob, add it to the dataclass and document it in `config.yml`.
- Logging: `log = logging.getLogger(__name__)` at module top; log level is set once in `server/main.py`. INFO for lifecycle, DEBUG for per-chunk chatter.
- Async: WebSocket handlers and endpoints are async, but heavy model work is sync — wrap via `asyncio.to_thread(...)` at the call site (see `set_avatar` in `main.py`).
- Concurrency: `VideoEngine` serialises generation with `self._lock`. Don't call model methods from another thread without holding it.
- Tests: every non-trivial logic change in `server/video.py` or `server/pipeline.py` should have a corresponding `tests/unit/` test. GPU-dependent behaviour goes in `tests/component/`.

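The Async and Concurrency conventions above can be sketched roughly as follows. This is an illustration under stated assumptions, not code from the repo: `FakeEngine` and `handle_request` are hypothetical names, and the `time.sleep` call stands in for real model work. Only `asyncio.to_thread` and the idea of a lock around generation come from the text.

```python
import asyncio
import threading
import time


class FakeEngine:
    """Hypothetical stand-in for a sync, GPU-bound engine such as VideoEngine."""

    def __init__(self) -> None:
        self._lock = threading.Lock()  # serialises generation across threads
        self.calls = 0

    def generate(self, prompt: str) -> str:
        # Blocking work: hold the lock for the whole call so two worker
        # threads never touch model state at the same time.
        with self._lock:
            time.sleep(0.01)  # placeholder for real model work
            self.calls += 1
            return f"clip for {prompt!r}"


async def handle_request(engine: FakeEngine, prompt: str) -> str:
    # Wrap the sync call at the call site so the event loop stays responsive.
    return await asyncio.to_thread(engine.generate, prompt)


async def main() -> list[str]:
    engine = FakeEngine()
    # Both requests run in worker threads; the lock makes them take turns.
    return await asyncio.gather(
        handle_request(engine, "a"),
        handle_request(engine, "b"),
    )


results = asyncio.run(main())
```

Taking the lock inside the engine method, rather than in each caller, is one way to make it hard to reach model state without it.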
## Gotchas

- **GPU architecture.** This is tuned for RTX 5090 / Blackwell (SM120) with PyTorch 2.8 + Triton 3.4. Several upstream kernels (flashinfer, flash_attn3, sgl_kernel fp8 matmul, Triton-fused scale/shift) are broken or unavailable there. See [server/video_models/AGENT.md](server/video_models/AGENT.md) before touching the Wan2.2 wrapper.
- **First launch is slow.** Hugging Face downloads land in the `huggingface-cache` Docker volume; a cold run pulls >20 GB.
- **Wan-AI/Wan2.2-I2V-A14B** ships bf16 DIT shards we don't want — `BASE_REPO_IGNORE_PATTERNS` in [wan22.py](server/video_models/wan22.py) excludes them. Keep that list in sync if the repo layout changes.
- **LoRA targets matter.** Wan2.2 is a MoE (high_noise + low_noise sub-models). A LoRA with the wrong `target` loads silently and produces subtly wrong output.
- **Don't mix audio+video state.** The audio pipeline must not block on video generation; video is produced for a turn *after* the full reply audio is available, and sent as a separate message.

## When in doubt

- Run `python -m pytest tests/unit -v` — it's fast and catches most regressions.
- For GPU changes, run the lowest-numbered relevant component test first; they're ordered to isolate failure to a single stage.
- Check [memory](../../.claude/projects/c--Users-bheth-Documents-voice-chat/memory/) (auto-loaded) for prior-session findings that aren't in the code.
+12
-3
@@ -61,10 +61,19 @@ RUN python3.11 -m pip install --no-cache-dir --no-deps \
|
||||
"sgl-kernel @ https://github.com/sgl-project/whl/releases/download/v0.3.14.post1/sgl_kernel-0.3.14.post1%2Bcu128-cp310-abi3-manylinux2014_x86_64.whl" || \
|
||||
echo "sgl-kernel install failed — fp8 T5 will fall back to bf16"
|
||||
#
|
||||
# MuseTalk (audio-driven lip-sync) — same story.
|
||||
# MuseTalk (audio-driven lip-sync) — installed from the bhetherman/MuseTalk
|
||||
# fork checked in as a submodule at third_party/MuseTalk. The upstream repo
|
||||
# has no setup.py / pyproject.toml; our fork adds them so `pip install .`
|
||||
# just works. We deliberately do NOT install its requirements.txt (it pins
|
||||
# numpy==1.23.5, transformers==4.39.2, tensorflow==2.12.0 which conflict
|
||||
# with the rest of the stack) — instead we install its real runtime deps
|
||||
# explicitly here.
|
||||
COPY third_party/MuseTalk /opt/MuseTalk
|
||||
RUN python3.11 -m pip install --no-cache-dir --no-deps /opt/MuseTalk || \
|
||||
echo "MuseTalk install failed — config.video.musetalk.enabled must stay false until fixed"
|
||||
RUN python3.11 -m pip install --no-cache-dir \
|
||||
"git+https://github.com/TMElyralab/MuseTalk.git" || \
|
||||
echo "MuseTalk install failed — config.video.enabled must stay false until fixed"
|
||||
librosa einops omegaconf ffmpeg-python || \
|
||||
echo "MuseTalk runtime deps install failed"
|
||||
#
|
||||
# LoRA directory (user drops .safetensors here; bind-mounted in compose).
|
||||
RUN mkdir -p /cache/loras
|
||||
|
||||
@@ -1,21 +1,29 @@
|
||||
# Voice Chat
|
||||
|
||||
A real-time voice conversation app powered by local AI models. Speak into your mic and get spoken responses back — all running on your own GPU with no cloud APIs.
|
||||
A real-time voice conversation app powered by local AI models. Speak into your mic and get spoken responses back — and optionally a lip-synced talking-head video of a chosen avatar — all running on your own GPU with no cloud APIs.
|
||||
|
||||
## Pipeline
|
||||
|
||||
**Mic input** → **VAD** (Silero ONNX) → **ASR** (Qwen3-ASR-0.6B) → **LLM** (Qwen3.5-0.8B) → **TTS** (Kokoro) → **Speaker output**
|
||||
|
||||
When the optional video stack is enabled, each assistant turn also produces an MP4 via:
|
||||
|
||||
**TTS audio + avatar image** → **Wan2.2-Lightning I2V** (LightX2V, fp8 or GGUF) → **MuseTalk lip-sync** → **ffmpeg mux** → **`speaking_clip` WebSocket message**
|
||||
|
||||
- **VAD** — Silero VAD via ONNX Runtime, detects speech/silence boundaries on CPU
|
||||
- **ASR** — Qwen3-ASR-0.6B, bfloat16 on CUDA
|
||||
- **LLM** — Qwen3.5-0.8B, loaded via transformers
|
||||
- **LLM** — Qwen3.5-0.8B (local) or any model served by LM Studio
|
||||
- **TTS** — Kokoro, streams sentence-by-sentence audio at 24 kHz
|
||||
- **Barge-in** — interrupt the assistant mid-response by speaking
|
||||
- **Video (optional)** — Wan2.2 I2V 14B MoE with LightX2V distill LoRAs, fp8 or GGUF-quantised DIT, lip-synced by MuseTalk. Two modes:
|
||||
- `library` — pre-bakes a small set of speaking base clips per avatar, picks round-robin per turn, lip-syncs on the fly
|
||||
- `reflective` — generates a fresh Wan2.2 clip per turn from a prompt derived from the assistant's reply
|
||||
|
||||
## Requirements
|
||||
|
||||
- NVIDIA GPU with CUDA 12.8 support
|
||||
- NVIDIA GPU with CUDA 12.8 support (tested on RTX 5090 / SM120 Blackwell)
|
||||
- Docker + Docker Compose with the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
|
||||
- ~24 GB VRAM recommended when video is enabled (fp8); ~16 GB with `gguf-Q4_K_M`
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -27,6 +35,17 @@ Then open [http://localhost:8000](http://localhost:8000) in your browser.
|
||||
|
||||
Models are downloaded from Hugging Face on first launch and cached in a Docker volume (`huggingface-cache`) so they persist across rebuilds.
|
||||
|
||||
## Configuration
|
||||
|
||||
All runtime behaviour is driven by [config.yml](config.yml):
|
||||
|
||||
- `llm.backend` — `local` (in-process transformers) or `lmstudio` (talks to a local LM Studio server)
|
||||
- `llm.system_prompt`, `llm.max_cache_tokens` — prompt and KV-cache limit per session
|
||||
- `video.enabled` — master toggle for the avatar video stack. When `false`, no video models load and the app behaves exactly like the audio-only pipeline
|
||||
- `video.mode` — `library` or `reflective`
|
||||
- `video.models.wan22_dit_quant_scheme` — `gguf-Q8_0` (default) or `gguf-Q4_K_M` for lower VRAM; any other GGUF level that LightX2V supports is also accepted. The dense 5B Turbo DIT is GGUF-only
|
||||
- `video.loras` — list of LoRA adapters applied to the dense Wan2.2-TI2V-5B DIT at load time. Each entry has `path`, `weight`, `target` (always `both` — the 5B DIT is not MoE; legacy `high_noise`/`low_noise` values are coerced), and optional `name`. User LoRAs are mounted from `./loras/` into the container at `/cache/loras/`
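
As a rough illustration of the `video.loras` coercion described above, the sketch below parses a list of raw entries and forces every `target` to `both`. `LoRAEntry` and `parse_loras` are hypothetical names, not the repo's `LoRASpec` or `VideoConfig.from_dict`, and the default `weight` of 1.0 and empty `name` are assumptions.

```python
import logging
from dataclasses import dataclass

log = logging.getLogger(__name__)


@dataclass
class LoRAEntry:
    """Hypothetical mirror of one `video.loras` item."""

    path: str
    weight: float = 1.0  # assumed default
    target: str = "both"
    name: str = ""  # assumed default


def parse_loras(raw_entries: list[dict]) -> list[LoRAEntry]:
    entries = []
    for entry in raw_entries or []:
        if not entry or "path" not in entry:
            continue  # skip malformed items instead of failing startup
        target = str(entry.get("target", "both")).lower()
        if target != "both":
            # Legacy MoE values (high_noise / low_noise) are coerced.
            log.warning("LoRA %s: coercing target %r to 'both'", entry["path"], target)
            target = "both"
        entries.append(
            LoRAEntry(
                path=str(entry["path"]),
                weight=float(entry.get("weight", 1.0)),
                target=target,
                name=str(entry.get("name", "")),
            )
        )
    return entries


parsed = parse_loras(
    [
        {"path": "/cache/loras/your-5b-lora.safetensors", "target": "high_noise"},
        {"weight": 0.5},  # no path, so this entry is dropped
    ]
)
```

Logging a warning on coercion, rather than raising, keeps an older MoE-era config loadable while still flagging it.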
|
||||
|
||||
## Local Development (without Docker)
|
||||
|
||||
```bash
|
||||
@@ -45,17 +64,41 @@ python run.py
|
||||
|
||||
The server starts on port 8000.
|
||||
|
||||
## API
|
||||
|
||||
- `GET /` — browser UI
|
||||
- `WS /ws/chat` — bidirectional audio + control WebSocket
|
||||
- `POST /api/set-voice` — switch the Kokoro voice
|
||||
- `POST /api/set-avatar` *(video only)* — upload an avatar image; (re)generates idle + library clips
|
||||
- `GET /api/idle-clip` *(video only)* — cached idle loop MP4
|
||||
- `POST /api/set-video-mode` *(video only)* — switch between `off` / `library` / `reflective`
|
||||
- `POST /api/reload-loras` *(video only)* — hot-swap the LoRA stack; regenerates the idle clip
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
server/
|
||||
main.py — FastAPI app, WebSocket endpoint
|
||||
models.py — Model loading and management
|
||||
pipeline.py — VAD -> ASR -> LLM -> TTS orchestration
|
||||
vad.py — Silero VAD (ONNX) streaming wrapper
|
||||
asr.py — Speech recognition engine
|
||||
llm.py — Language model engine
|
||||
tts.py — Kokoro TTS engine
|
||||
audio_utils.py — PCM/float32 conversion helpers
|
||||
static/ — Browser UI (HTML/JS/CSS)
|
||||
main.py — FastAPI app, WebSocket + video endpoints
|
||||
models.py — Model loading and management (audio + optional video)
|
||||
pipeline.py — VAD -> ASR -> LLM -> TTS orchestration, video branch
|
||||
config.py — config.yml parsing
|
||||
vad.py — Silero VAD (ONNX) streaming wrapper
|
||||
asr.py — Speech recognition engine
|
||||
llm.py — Language model engine (local + LM Studio backends)
|
||||
tts.py — Kokoro TTS engine
|
||||
audio_utils.py — PCM/float32 conversion helpers
|
||||
video.py — VideoEngine orchestrator + VideoConfig + LoRASpec
|
||||
video_models/
|
||||
wan22.py — LightX2V Wan2.2 I2V pipeline wrapper (fp8 + GGUF)
|
||||
musetalk.py — MuseTalk lip-sync wrapper
|
||||
muxer.py — ffmpeg helpers: frames -> MP4, frames+audio -> MP4
|
||||
|
||||
configs/
|
||||
lightx2v/ — LightX2V inference config templates (fp8 + GGUF variants)
|
||||
|
||||
static/ — Browser UI (HTML/JS/CSS)
|
||||
avatars/ — uploaded avatar images (gitignored)
|
||||
loras/ — user-supplied Wan2.2 LoRAs, mounted into the container
|
||||
reference_audio/ — Kokoro voice reference samples
|
||||
tests/ — unit + component tests (see tests/README.md)
|
||||
```
|
||||
|
||||
+33
-29
@@ -13,9 +13,9 @@ llm:
|
||||
url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker
|
||||
model: "" # leave empty to use whatever model LM Studio has loaded
|
||||
|
||||
# Avatar video generation (Wan2.2-Lightning fp8 via LightX2V + MuseTalk lip-sync)
|
||||
# Avatar video generation (Wan2.2-TI2V-5B-Turbo GGUF via LightX2V + MuseTalk lip-sync)
|
||||
video:
|
||||
enabled: false # master toggle — when false, video models are not loaded
|
||||
enabled: true # master toggle — when false, video models are not loaded
|
||||
backend: lightx2v # only option for now
|
||||
mode: reflective # "library" (pre-baked clips) | "reflective" (fresh per turn)
|
||||
resolution: 480 # 480 or 720
|
||||
@@ -25,6 +25,12 @@ video:
|
||||
base_clip_count: 4 # how many speaking base clips to pre-generate per avatar
|
||||
base_clip_seconds: 6 # duration of each pre-baked clip
|
||||
|
||||
# MuseTalk audio-driven lip-sync. When disabled, Wan2.2 base frames are
|
||||
# used as-is without a lip-sync pass — useful when MuseTalk isn't installed
|
||||
# or while iterating on the base pipeline.
|
||||
musetalk:
|
||||
enabled: false # toggle lip-sync on/off
|
||||
|
||||
reflective:
|
||||
clip_seconds: 5 # target length of each fresh Wan2.2 clip per turn
|
||||
clip_prompt_template: >-
|
||||
@@ -33,35 +39,33 @@ video:
|
||||
prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint}
|
||||
|
||||
# Model sources for the video stack. T5/VAE/tokenizer come from the
|
||||
# Wan-AI base repo. DIT weights come from wan22_dit_repo in the format
|
||||
# specified by wan22_dit_quant_scheme. Both repos download on first run
|
||||
# into HF_HOME=/cache/huggingface.
|
||||
# Wan-AI base repo. The single dense DIT comes from wan22_dit_repo as
|
||||
# GGUF (Turbo 4-step distill). Both repos download on first run into
|
||||
# HF_HOME=/cache/huggingface.
|
||||
#
|
||||
# Supported dit_quant_scheme values:
|
||||
# fp8-sgl — fp8 e4m3 safetensors (~15 GB/expert, from lightx2v/Wan2.2-Distill-Models)
|
||||
# gguf-Q4_K_M — GGUF 4-bit (~9.65 GB/expert, from QuantStack/Wan2.2-I2V-A14B-GGUF)
|
||||
# gguf-Q8_0 — GGUF 8-bit (~15.4 GB/expert)
|
||||
# (any gguf-<level> supported by LightX2V — see base_model.py MM_WEIGHT_REGISTER)
|
||||
# Supported dit_quant_scheme values (dense 5B Turbo — GGUF only):
|
||||
# gguf-Q8_0 — 8-bit, ~6 GB DIT, ~6.5 GB VRAM at load (default)
|
||||
# gguf-Q4_K_M — 4-bit, ~3.5 GB DIT, lower VRAM for tight budgets
|
||||
# (any gguf-<level> published in hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF)
|
||||
models:
|
||||
wan22_base_repo: Wan-AI/Wan2.2-I2V-A14B
|
||||
wan22_dit_repo: QuantStack/Wan2.2-I2V-A14B-GGUF
|
||||
wan22_dit_quant_scheme: gguf-Q4_K_M
|
||||
wan22_base_repo: Wan-AI/Wan2.2-TI2V-5B
|
||||
wan22_dit_repo: hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF
|
||||
wan22_dit_quant_scheme: gguf-Q8_0
|
||||
wan22_t5_quantized: true
|
||||
wan22_model_cls: wan2.2_moe_distill
|
||||
wan22_config_json: /app/configs/lightx2v/wan22_i2v_gguf_distill.json
|
||||
wan22_model_cls: wan2.2
|
||||
wan22_config_json: /app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json
|
||||
musetalk_path: TMElyralab/MuseTalk
|
||||
|
||||
# LoRAs applied to the fp8 base at load time via runtime switch_lora.
|
||||
# Wan2.2 is a MoE with separate high-noise and low-noise sub-models —
|
||||
# `target` picks which sub-model each LoRA attaches to. The two files
|
||||
# below are the user-supplied ./loras/wan22-[HL]-e8.safetensors mounted
|
||||
# into the container at /cache/loras/.
|
||||
loras:
|
||||
- path: /cache/loras/wan22-H-e8.safetensors
|
||||
weight: 1.0
|
||||
target: high_noise
|
||||
name: wan22-H-e8
|
||||
- path: /cache/loras/wan22-L-e8.safetensors
|
||||
weight: 1.0
|
||||
target: low_noise
|
||||
name: wan22-L-e8
|
||||
# LoRAs applied to the dense 5B DIT at load time via LightX2V's
|
||||
# lora_dynamic_apply path (merged during GGUF dequant). Dense has a
|
||||
# single set of weights so `target` is always `both`.
|
||||
#
|
||||
# The old MoE-trained wan22-H-e8 / wan22-L-e8 LoRAs are NOT compatible
|
||||
# with the 5B DIT and are disabled here. Future 5B-compatible LoRAs
|
||||
# should follow the shape shown below.
|
||||
loras: []
|
||||
# loras:
|
||||
# - path: /cache/loras/your-5b-lora.safetensors
|
||||
# weight: 1.0
|
||||
# target: both
|
||||
# name: your-5b-lora
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
{
|
||||
"_comment": "Wan2.2 i2v MoE 4-step distill, fp8 e4m3 quantized. Built for 24 GB-class GPUs — cpu_offload keeps DIT layers swapping in block-by-block. Derived from LightX2V's configs/distill/wan22/wan_moe_i2v_distill_4090.json plus the quant scheme + ckpt overrides from wan_moe_i2v_distill_quant.json. high_noise_quantized_ckpt / low_noise_quantized_ckpt are filled in at runtime by server/video_models/wan22.py with absolute paths to the files downloaded into HF_HOME.",
|
||||
|
||||
"infer_steps": 4,
|
||||
"target_video_length": 81,
|
||||
"text_len": 512,
|
||||
|
||||
"resize_mode": "adaptive",
|
||||
"resolution": "480p",
|
||||
"target_height": 480,
|
||||
"target_width": 480,
|
||||
"fps": 16,
|
||||
|
||||
"self_attn_1_type": "flash_attn3",
|
||||
"cross_attn_1_type": "flash_attn3",
|
||||
"cross_attn_2_type": "flash_attn3",
|
||||
|
||||
"sample_guide_scale": [3.5, 3.5],
|
||||
"sample_shift": 5.0,
|
||||
"enable_cfg": false,
|
||||
|
||||
"cpu_offload": true,
|
||||
"offload_granularity": "block",
|
||||
"lazy_load": true,
|
||||
"t5_cpu_offload": true,
|
||||
"vae_cpu_offload": false,
|
||||
|
||||
"use_image_encoder": false,
|
||||
|
||||
"boundary_step_index": 2,
|
||||
"denoising_step_list": [1000, 750, 500, 250],
|
||||
|
||||
"dit_quantized": true,
|
||||
"dit_quant_scheme": "fp8-sgl",
|
||||
"t5_quantized": false
|
||||
}
|
||||
@@ -0,0 +1,40 @@
{
  "_comment": "LightX2V config for Wan2.2-TI2V-5B-Turbo (dense, GGUF). Single DIT checkpoint (not MoE). dit_quantized_ckpt is filled in at runtime by Wan22Pipeline.",

  "infer_steps": 4,
  "target_video_length": 81,
  "text_len": 512,

  "resize_mode": "adaptive",
  "resolution": "480p",
  "target_height": 480,
  "target_width": 480,
  "fps": 16,

  "vae_stride": [4, 16, 16],
  "num_channels_latents": 48,

  "self_attn_1_type": "torch_sdpa",
  "cross_attn_1_type": "torch_sdpa",
  "cross_attn_2_type": "torch_sdpa",
  "modulate_type": "torch",
  "rope_type": "torch",

  "sample_guide_scale": 1.0,
  "sample_shift": 5.0,
  "enable_cfg": false,

  "cpu_offload": false,
  "offload_granularity": "model",
  "t5_cpu_offload": true,
  "vae_cpu_offload": false,

  "use_image_encoder": false,

  "denoising_step_list": [1000, 750, 500, 250],

  "dit_quantized": true,
  "dit_quant_scheme": "gguf-Q8_0",
  "t5_quantized": true,
  "t5_quant_scheme": "fp8-sgl"
}
@@ -1,41 +0,0 @@
|
||||
{
|
||||
"_comment": "Wan2.2 i2v MoE 4-step distill, GGUF quantized. Uses QuantStack/Wan2.2-I2V-A14B-GGUF checkpoints instead of fp8 safetensors. GGUF does not support block-level offload so offload_granularity is set to 'model' — the entire DIT is moved to GPU when active. With Q4_K_M (~9.65 GB per expert) this fits comfortably in 24+ GB VRAM. high_noise_quantized_ckpt / low_noise_quantized_ckpt are filled in at runtime by server/video_models/wan22.py. IMPORTANT: GGUF dequantizes to fp16, so you must set DTYPE=FP16 in the container environment.",
|
||||
|
||||
"infer_steps": 4,
|
||||
"target_video_length": 81,
|
||||
"text_len": 512,
|
||||
|
||||
"resize_mode": "adaptive",
|
||||
"resolution": "480p",
|
||||
"target_height": 480,
|
||||
"target_width": 480,
|
||||
"fps": 16,
|
||||
|
||||
"_comment_attn": "flash_attn3/sageattn3 aren't installed (no Blackwell-ready pre-built wheels). Use PyTorch SDPA which works on SM120.",
|
||||
"self_attn_1_type": "torch_sdpa",
|
||||
"cross_attn_1_type": "torch_sdpa",
|
||||
"cross_attn_2_type": "torch_sdpa",
|
||||
|
||||
"_comment_modulate": "Triton fuse_scale_shift_kernel segfaults during JIT compile on Blackwell SM120 (triton 3.4 + cu128). Use the PyTorch modulate fallback until the Triton issue is resolved.",
|
||||
"modulate_type": "torch",
|
||||
"_comment_rope": "flashinfer not installed; fall back to PyTorch rope.",
|
||||
"rope_type": "torch",
|
||||
|
||||
"sample_guide_scale": [3.5, 3.5],
|
||||
"sample_shift": 5.0,
|
||||
"enable_cfg": false,
|
||||
|
||||
"cpu_offload": true,
|
||||
"offload_granularity": "model",
|
||||
"t5_cpu_offload": true,
|
||||
"vae_cpu_offload": false,
|
||||
|
||||
"use_image_encoder": false,
|
||||
|
||||
"boundary_step_index": 2,
|
||||
"denoising_step_list": [1000, 750, 500, 250],
|
||||
|
||||
"dit_quantized": true,
|
||||
"dit_quant_scheme": "gguf-Q4_K_M",
|
||||
"t5_quantized": false
|
||||
}
|
||||
@@ -0,0 +1,57 @@
# Agent guide — server/

The audio pipeline and FastAPI surface. The video stack lives under [video_models/](video_models/) and has its own guide.

## Module map

- [main.py](main.py) — FastAPI app, lifespan, `/ws/chat` WebSocket, video/avatar HTTP endpoints. Keep business logic out of here; this is a transport layer.
- [models.py](models.py) — `ModelManager.load_all()`. All models are loaded once at startup in a fixed order: VAD → ASR → LLM → TTS → (optional) Video. `video_engine` stays `None` when `config.video.enabled` is false — callers MUST tolerate that.
- [config.py](config.py) — thin YAML loader, exposes a single `config` dict. Don't scatter `yaml.safe_load` elsewhere.
- [pipeline.py](pipeline.py) — `ConversationSession`, one instance per WebSocket. Owns per-session state (VAD stream, conversation history, KV cache, cancel event). Orchestrates VAD → ASR → LLM → TTS and optionally the video branch.
- [vad.py](vad.py) — Silero VAD via ONNX Runtime on CPU. `StreamingVAD.process_chunk(pcm_16k) → utterance | None`. Returns a full utterance only on speech→silence transition.
- [asr.py](asr.py) — Qwen3-ASR wrapper. Sync, called under `asyncio.to_thread`.
- [llm.py](llm.py) — two backends behind a common `generate(history, max_new_tokens, kv_cache_state) → (text, KVCacheState)` signature: `LLMEngine` (local transformers) and `LMStudioEngine` (HTTP, no KV cache).
- [tts.py](tts.py) — Kokoro wrapper. The per-segment generator yields `(graphemes, _ps, audio_f32)` tuples at 24 kHz.
- [audio_utils.py](audio_utils.py) — `pcm_bytes_to_float32` / `float32_to_pcm_bytes`. The WebSocket protocol is 16 kHz int16 PCM in, 24 kHz int16 PCM out.
- [video.py](video.py) — `VideoConfig`, `LoRASpec`, `VideoEngine`. Orchestrates Wan2.2 + MuseTalk. Gated by `config.video.enabled`.
- [video_models/](video_models/) — see [video_models/AGENT.md](video_models/AGENT.md) for Blackwell/GGUF gotchas.

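To make the int16 PCM wire format mentioned for `audio_utils.py` concrete, here is a minimal standard-library sketch of the two conversions. It is not the repo's implementation: the function names are deliberately different, and the little-endian byte order and the 32768 / 32767 scale factors are assumptions.

```python
import struct


def pcm_bytes_to_floats(pcm: bytes) -> list[float]:
    """Decode little-endian int16 PCM into floats in [-1.0, 1.0)."""
    count = len(pcm) // 2  # two bytes per sample; ignore a trailing odd byte
    ints = struct.unpack(f"<{count}h", pcm[: count * 2])
    return [i / 32768.0 for i in ints]


def floats_to_pcm_bytes(samples: list[float]) -> bytes:
    """Encode floats as little-endian int16 PCM, clipping to [-1.0, 1.0]."""
    ints = [int(round(max(-1.0, min(1.0, s)) * 32767)) for s in samples]
    return struct.pack(f"<{len(ints)}h", *ints)


encoded = floats_to_pcm_bytes([0.0, 1.0, -1.0, 2.0])  # 2.0 is clipped to 1.0
decoded = pcm_bytes_to_floats(struct.pack("<2h", 0, -32768))
```

Clipping before scaling avoids integer overflow when a synthesised sample slightly exceeds full scale.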
## Session lifecycle (pipeline.py)

1. Client connects → `ConversationSession(models, send_json, send_bytes)` → `start()` emits `{type: "status", state: "listening"}` and a `{type: "video_mode", ...}` hint.
2. Inbound binary frames are 16 kHz int16 PCM → `handle_audio_chunk` → VAD. On speech→silence the session kicks off `_process_utterance` as an `asyncio.Task` so it never blocks the WebSocket receive loop.
3. `_process_utterance` flow: ASR → LLM → TTS stream. Each blocking call wraps in `asyncio.to_thread`.
4. TTS output is split into short segments via `_split_into_segments` before synthesis so streaming chunks stay small.
5. TTS runs on a background `threading.Thread`, feeding a `queue.Queue` that the async loop drains.

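Step 5 above, a background thread feeding a `queue.Queue` that the async loop drains, can be sketched as below. This is a simplified illustration, not the code in `pipeline.py`: `tts_worker`, `drain`, `speak` and the sentinel object are hypothetical, and how the real loop waits on the queue is not stated in this guide.

```python
import asyncio
import queue
import threading

_DONE = object()  # sentinel marking the end of the stream


def tts_worker(segments: list[str], out: queue.Queue) -> None:
    """Hypothetical producer standing in for the blocking TTS generator."""
    for seg in segments:
        out.put(f"audio:{seg}".encode())  # placeholder for synthesised PCM
    out.put(_DONE)


async def drain(out: queue.Queue) -> list[bytes]:
    chunks = []
    while True:
        # queue.Queue.get blocks, so wait for it in a worker thread
        # instead of on the event loop.
        item = await asyncio.to_thread(out.get)
        if item is _DONE:
            return chunks
        chunks.append(item)


async def speak(segments: list[str]) -> list[bytes]:
    out: queue.Queue = queue.Queue()
    threading.Thread(target=tts_worker, args=(segments, out), daemon=True).start()
    return await drain(out)


chunks = asyncio.run(speak(["Hello.", "How are you?"]))
```

In a real session each drained chunk would be forwarded to the client as it arrives rather than collected into a list.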
## Barge-in

- `cancel_event: threading.Event` is the single stop signal. Checked between every stage and inside the TTS queue drain loop.
- Two ways to trigger: new VAD utterance while `is_responding` (`handle_audio_chunk`), or an explicit `{type: "interrupt", last_chunk_id}` text message (`interrupt`).
- On cancel: set the event, send `{type: "interrupt"}` to the client so it flushes its audio buffer, and discard any pending video clip.
- Don't add work after `cancel_event.is_set()` without checking — that's how zombie audio/video reaches the client after a barge-in.

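The "check between every stage" rule above can be illustrated with a small sketch. `run_turn` and the stage list are hypothetical; only the use of a `threading.Event` as the single stop signal comes from the text.

```python
import threading
from typing import Callable


def run_turn(
    stages: list[tuple[str, Callable[[], None]]],
    cancel_event: threading.Event,
) -> list[str]:
    """Run stages in order, stopping as soon as a barge-in is signalled."""
    completed = []
    for name, stage in stages:
        if cancel_event.is_set():
            break  # never start new work after a cancel
        stage()
        completed.append(name)
    return completed


cancel = threading.Event()
stages = [
    ("asr", lambda: None),
    ("llm", cancel.set),  # simulate the user barging in during the LLM stage
    ("tts", lambda: None),
]
completed = run_turn(stages, cancel)
```

Here the TTS stage never runs because the event was set while the LLM stage was in flight, which is the behaviour that keeps stale audio from reaching the client.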
## Video branch

The audio pipeline changes shape when `video_engine.is_ready()`:

- PCM chunks and `response_text` are **not** streamed during the turn — they're buffered.
- TTS audio is concatenated into one float32 array.
- After TTS completes, `video_engine.generate_speaking_clip(audio, sr, reply_text)` renders the MP4 (blocking, wrapped in `to_thread`).
- The full clip + final text is sent as a single `speaking_clip` message.

If you extend the audio pipeline, preserve this dual-mode behaviour: the client's UX is very different between the two paths, and mixing them (e.g. sending PCM *and* a clip) will double-play the audio.

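The dual-mode behaviour described above can be sketched as a pure function that decides which messages a turn emits. The message dictionaries, the `pcm` type name and `plan_messages` itself are hypothetical, as is the ordering of text after audio on the audio-only path; the point is only that a turn takes exactly one of the two paths.

```python
def plan_messages(
    pcm_chunks: list[bytes], reply_text: str, video_ready: bool
) -> list[dict]:
    """Return the (hypothetical) messages one assistant turn would emit."""
    if video_ready:
        # Video path: nothing is streamed mid-turn. Audio is buffered and
        # delivered once, together with the final text.
        buffered = b"".join(pcm_chunks)
        return [
            {"type": "speaking_clip", "audio_len": len(buffered), "text": reply_text}
        ]
    # Audio-only path: stream each PCM chunk, then the text.
    messages = [{"type": "pcm", "audio_len": len(chunk)} for chunk in pcm_chunks]
    messages.append({"type": "response_text", "text": reply_text})
    return messages


audio_only = plan_messages([b"\x00\x01", b"\x02\x03"], "hi", video_ready=False)
with_video = plan_messages([b"\x00\x01", b"\x02\x03"], "hi", video_ready=True)
```

Keeping the branch in one place makes it harder to send PCM and a clip for the same turn by accident.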
## Conventions

- Dataclasses for structured config (see `VideoConfig`, `LoRASpec`). Parse once in `*.from_dict`; don't re-read `config.yml` mid-session.
- `asyncio.to_thread` for any sync model call from an async context. Never call `.generate()` / `.transcribe()` / `.pipeline()` directly on the event loop.
- Locks: `VideoEngine._lock` serialises model state mutations; `ConversationSession` is not locked because each WebSocket gets its own instance.
- Logging: one `log = logging.getLogger(__name__)` per module. INFO for lifecycle and per-turn milestones; avoid DEBUG spam in hot loops.
- Keep `main.py` thin. New endpoints should delegate to a method on `ModelManager` or an engine class.

## Testing

- `tests/unit/test_pipeline_video_branch.py` — the video vs. audio path selection. Update it if you change the `use_video` condition.
- `tests/unit/test_video_config.py` / `test_video_engine_logic.py` — config parsing and the pure logic in `video.py`.
- Component tests live in `tests/component/` and require the Docker GPU environment. See [tests/README.md](../tests/README.md).
+1
-1
@@ -133,7 +133,7 @@ async def reload_loras(body: dict):
|
||||
if not entry or "path" not in entry:
|
||||
continue
|
||||
target = str(entry.get("target", "both")).lower()
|
||||
if target not in ("high_noise", "low_noise", "both"):
|
||||
if target != "both":
|
||||
target = "both"
|
||||
specs.append(
|
||||
LoRASpec(
|
||||
|
||||
+1
-2
@@ -121,8 +121,7 @@ class ModelManager:
|
||||
log.info("Loading avatar video engine...")
|
||||
cfg = VideoConfig.from_dict(video_cfg_raw)
|
||||
self.video_engine = VideoEngine(cfg)
|
||||
if cfg.loras:
|
||||
self.video_engine.load_loras(cfg.loras)
|
||||
self.video_engine.load_models()
|
||||
log.info("Avatar video engine loaded (mode=%s).", cfg.mode)
|
||||
|
||||
def create_vad(self) -> StreamingVAD:
|
||||
|
||||
+49
-37
@@ -18,18 +18,18 @@ import numpy as np
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LoRATarget = Literal["high_noise", "low_noise", "both"]
|
||||
LoRATarget = Literal["both"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRASpec:
|
||||
"""One LoRA adapter entry from ``config.video.loras``.
|
||||
|
||||
Wan2.2 I2V is a Mixture-of-Experts model with separate high-noise and
|
||||
low-noise sub-models. Most LightX2V distill LoRAs come paired (one per
|
||||
sub-model) and must be applied to the correct target. Allow
|
||||
``target="both"`` for LoRAs that should be applied to both sub-models
|
||||
(e.g. style LoRAs).
|
||||
The dense Wan2.2-TI2V-5B DIT has a single set of weights (no MoE
|
||||
experts), so ``target`` is always ``"both"``. The field is kept for
|
||||
forward compatibility and config-file compatibility with older MoE
|
||||
configs — legacy ``"high_noise"`` / ``"low_noise"`` values are coerced
|
||||
to ``"both"`` in ``VideoConfig.from_dict``.
|
||||
"""
|
||||
|
||||
path: str
|
||||
@@ -60,18 +60,20 @@ class VideoConfig:
|
||||
# Model paths — can be overridden via config.yml.video.models.
|
||||
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
|
||||
# The bf16 DIT shards in this repo are skipped — we
|
||||
# replace them with quantised files from wan22_dit_repo.
|
||||
# wan22_dit_repo : HF repo id (or local dir) providing the quantised
|
||||
# DIT checkpoints (fp8 or GGUF).
|
||||
# wan22_dit_quant_scheme : quantisation format, e.g. "fp8-sgl" or "gguf-Q4_K_M".
|
||||
# replace them with a quantised GGUF from wan22_dit_repo.
|
||||
# wan22_dit_repo : HF repo id (or local dir) providing the single
|
||||
# dense GGUF DIT checkpoint (5B Turbo).
|
||||
# wan22_dit_quant_scheme : GGUF quant level, e.g. "gguf-Q8_0" (default)
|
||||
# or "gguf-Q4_K_M" for lower VRAM.
|
||||
# wan22_config_json : path to the LightX2V inference config template the
|
||||
# Wan22Pipeline will fill in with absolute ckpt paths.
|
||||
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
|
||||
wan22_dit_repo: str = "lightx2v/Wan2.2-Distill-Models"
|
||||
wan22_dit_quant_scheme: str = "fp8-sgl"
|
||||
wan22_t5_quantized: bool = False
|
||||
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
||||
wan22_model_cls: str = "wan2.2_moe_distill"
|
||||
wan22_base_repo: str = "Wan-AI/Wan2.2-TI2V-5B"
|
||||
wan22_dit_repo: str = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||
wan22_dit_quant_scheme: str = "gguf-Q8_0"
|
||||
wan22_t5_quantized: bool = True
|
||||
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
|
||||
wan22_model_cls: str = "wan2.2"
|
||||
musetalk_enabled: bool = True
|
||||
musetalk_model_path: str = "TMElyralab/MuseTalk"
|
||||
|
||||
@classmethod
|
||||
@@ -92,9 +94,10 @@ class VideoConfig:
|
||||
if not entry or "path" not in entry:
|
||||
continue
|
||||
target = str(entry.get("target", "both")).lower()
|
||||
if target not in ("high_noise", "low_noise", "both"):
|
||||
if target != "both":
|
||||
log.warning(
|
||||
"LoRA %s: invalid target %r, defaulting to 'both'",
|
||||
"LoRA %s: target %r is MoE-era; coercing to 'both' "
|
||||
"(dense 5B has a single DIT).",
|
||||
entry.get("path"), target,
|
||||
)
|
||||
target = "both"
|
||||
@@ -122,30 +125,32 @@ class VideoConfig:
|
||||
reflective_prompt_reply_words=int(reflective.get("prompt_reply_words", 18)),
|
||||
loras=loras,
|
||||
wan22_base_repo=str(
|
||||
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
|
||||
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-TI2V-5B")
|
||||
),
|
||||
wan22_dit_repo=str(
|
||||
models_raw.get(
|
||||
"wan22_dit_repo",
|
||||
# Backwards compat: fall back to old key name.
|
||||
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models"),
|
||||
"hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF",
|
||||
)
|
||||
),
|
||||
wan22_dit_quant_scheme=str(
|
||||
models_raw.get("wan22_dit_quant_scheme", "fp8-sgl")
|
||||
models_raw.get("wan22_dit_quant_scheme", "gguf-Q8_0")
|
||||
),
|
||||
wan22_t5_quantized=bool(
|
||||
models_raw.get("wan22_t5_quantized", False)
|
||||
models_raw.get("wan22_t5_quantized", True)
|
||||
),
|
||||
wan22_config_json=str(
|
||||
models_raw.get(
|
||||
"wan22_config_json",
|
||||
"/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
|
||||
"/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json",
|
||||
)
|
||||
),
|
||||
wan22_model_cls=str(
|
||||
models_raw.get("wan22_model_cls", "wan2.2_moe_distill")
|
||||
models_raw.get("wan22_model_cls", "wan2.2")
|
||||
),
|
||||
musetalk_enabled=bool(raw.get("musetalk", {}).get("enabled", True))
|
||||
if isinstance(raw.get("musetalk"), dict)
|
||||
else bool(raw.get("musetalk_enabled", True)),
|
||||
musetalk_model_path=str(
|
||||
models_raw.get("musetalk_path", "TMElyralab/MuseTalk")
|
||||
),
|
||||
@@ -235,17 +240,22 @@ class VideoEngine:
|
||||
self._wan22.load_loras(self.cfg.loras)
|
||||
log.info("Wan2.2 pipeline ready.")
|
||||
|
||||
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
|
||||
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
|
||||
log.info("MuseTalk engine ready.")
|
||||
if self.cfg.musetalk_enabled:
|
||||
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
|
||||
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
|
||||
log.info("MuseTalk engine ready.")
|
||||
else:
|
||||
log.info("MuseTalk disabled via config — skipping lip-sync pass.")
|
||||
self._musetalk = None
|
||||
|
||||
# --- Readiness ------------------------------------------------------
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""True when an avatar is set and a speaking clip can be produced."""
|
||||
musetalk_ok = (not self.cfg.musetalk_enabled) or self._musetalk is not None
|
||||
return (
|
||||
self._wan22 is not None
|
||||
and self._musetalk is not None
|
||||
and musetalk_ok
|
||||
and self.avatar_path is not None
|
||||
and self.idle_clip_mp4 is not None
|
||||
)
|
||||
@@ -336,7 +346,6 @@ class VideoEngine:
|
||||
"(avatar set? models loaded?)"
|
||||
)
|
||||
assert self._wan22 is not None
|
||||
assert self._musetalk is not None
|
||||
|
||||
# 1. Source base frames.
|
||||
if self.cfg.mode == "library":
|
||||
@@ -351,13 +360,16 @@ class VideoEngine:
|
||||
seed=None, # random each turn
|
||||
)
|
||||
|
||||
# 2. Lip-sync the base frames to the given audio.
|
||||
synced_frames = self._musetalk.lip_sync(
|
||||
frames=base_frames,
|
||||
audio=audio_f32,
|
||||
sample_rate=sample_rate,
|
||||
fps=self.cfg.fps,
|
||||
)
|
||||
# 2. Lip-sync the base frames to the given audio (if enabled).
|
||||
if self._musetalk is not None:
|
||||
synced_frames = self._musetalk.lip_sync(
|
||||
frames=base_frames,
|
||||
audio=audio_f32,
|
||||
sample_rate=sample_rate,
|
||||
fps=self.cfg.fps,
|
||||
)
|
||||
else:
|
||||
synced_frames = base_frames
|
||||
|
||||
# 3. Mux frames + audio into an MP4.
|
||||
from server.video_models.muxer import frames_and_audio_to_mp4
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# Agent guide — server/video_models/

Wrappers around 3rd-party video models. These are the trickiest files in the repo: LightX2V's internals move quickly upstream, and the Blackwell (RTX 5090 / SM120) GPU path requires several non-obvious patches layered on top. Read this before editing.

## Scope

- [wan22.py](wan22.py) — LightX2V Wan2.2-I2V A14B MoE pipeline. Supports fp8 safetensors and GGUF DIT checkpoints. Loaded once at startup, held resident; per-turn calls go through `generate_i2v` and `switch_lora`.
- [musetalk.py](musetalk.py) — MuseTalk lip-sync over base frames + TTS audio.
- [muxer.py](muxer.py) — thin ffmpeg wrappers: frames → MP4 loop, frames + audio → MP4.

Nothing here is imported unless `config.video.enabled` is true.
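
The import gate can be sketched as a lazy import. This is an illustration only: `load_video_backend` is a hypothetical helper, and the default module name `server.video` is inferred from the file layout, not taken from the actual startup code.

```python
import importlib
from types import ModuleType
from typing import Optional


def load_video_backend(
    video_enabled: bool,
    module_name: str = "server.video",  # assumed module path, for illustration
) -> Optional[ModuleType]:
    """Import the video backend only when the config flag is set.

    Returning None keeps the audio pipeline free of any video imports.
    """
    if not video_enabled:
        return None
    return importlib.import_module(module_name)


# With the flag off, nothing is imported at all.
disabled = load_video_backend(False)

# With the flag on, the named module is imported on demand. The stdlib
# module "json" stands in here so the sketch runs anywhere.
enabled = load_video_backend(True, "json")
```

Keeping the import inside the function means a missing or broken video dependency cannot fail voice-only startup.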

## LightX2V entry points (upstream API)

Use these symbols, not internal/private ones:

```python
from lightx2v.utils.set_config import set_config
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.infer import init_runner

config = set_config(args)  # args is an argparse.Namespace
input_info = init_empty_input_info(args.task, args.support_tasks)
runner = init_runner(config)  # loads all weights — ONCE

update_input_info_from_dict(input_info, {...})  # per-turn inputs
runner.run_pipeline(input_info)  # MP4 written to save_result_path
runner.switch_lora(lora_path, strength)  # hot-swap
```

Keep model load out of the per-turn path — `init_runner` is expensive.
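
One way to keep the load out of the per-turn path is a small holder that builds the runner at most once. The sketch below is an assumption-laden illustration: `RunnerHolder` is a hypothetical class, and `loader` stands in for something like `lambda: init_runner(config)`; neither is part of LightX2V or this repo.

```python
import threading
from typing import Any, Callable, Optional


class RunnerHolder:
    """Build an expensive runner at most once and reuse it on every turn.

    Hypothetical helper: `loader` stands in for a call such as
    `lambda: init_runner(config)`.
    """

    def __init__(self, loader: Callable[[], Any]) -> None:
        self._loader = loader
        self._runner: Optional[Any] = None
        self._lock = threading.Lock()

    def get(self) -> Any:
        # Double-checked locking: concurrent first calls still load once.
        if self._runner is None:
            with self._lock:
                if self._runner is None:
                    self._runner = self._loader()
        return self._runner


# Stand-in loader that records how often the "weights" are loaded.
load_calls = []


def fake_loader() -> dict:
    load_calls.append(1)
    return {"ready": True}


holder = RunnerHolder(fake_loader)
first = holder.get()   # triggers the single load
second = holder.get()  # reuses the resident object
```

Per-turn code would then call `holder.get()` and only update inputs, never rebuild the runner.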

## Blackwell (SM120) patches — do not remove without testing

The GGUF pipeline works on a 5090 only because of layered patches in `wan22.py` and tuning in the LightX2V JSON configs under [configs/lightx2v/](../../configs/lightx2v/). Each patch exists because a stock upstream path segfaults or silently miscomputes on SM120.

**Dtype plumbing (GGUF path):**

- Default `DTYPE` must be `BF16` at `init_runner()` time — T5 offload buffers break if FP16 at init.
- Flip `BF16 → FP16` *after* `init_runner()`.
- Wrap T5 encoder so it runs under BF16 internally, then cast outputs `bf16 → fp16` before handing to the DIT. See `_patch_t5_dtype_for_gguf`.
- Cast VAE **both** layers: the inner `.model` via `.to(fp16)` **and** the outer `WanVAE` wrapper's `mean` / `inv_std` / `scale` tensors. Missing the wrapper tensors upcasts the latent during decode's `z/inv_std + mean`.
- DIT `pre_weight.patch_embedding.pin_weight` loads as fp32 (only `pin_bias` is fp16). Cast **and** re-pin via `.pin_memory()` — skipping re-pin segfaults during `to_cuda` H2D copy.
- `sgl_kernel`'s fp8 scaled matmul is patched to `torch._scaled_mm` in `_patch_fp8_scaled_mm_for_blackwell`.

**LightX2V JSON config (see `wan22_i2v_gguf_distill.json`):**

- `modulate_type: "torch"` — Triton `fuse_scale_shift_kernel` segfaults in `ast_to_ttir` on Triton 3.4 + SM120.
- `rope_type: "torch"` — flashinfer isn't installed.
- `self_attn_1_type` / `cross_attn_*_type`: `"torch_sdpa"` — flash_attn3 unavailable; `sageattention==1.0.6` from PyPI segfaults on Blackwell (newer requires source build).
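
A new JSON config can be checked against these choices before it is wired in. The following is a minimal sketch, not repo code: `find_unsafe_settings` is a hypothetical helper, `cross_attn_1_type` is an assumed concrete key matching the `cross_attn_*_type` pattern, and the "bad" values in the example are made up for illustration.

```python
from fnmatch import fnmatch

# Key patterns and required values, copied from the bullets above.
REQUIRED = {
    "modulate_type": "torch",
    "rope_type": "torch",
    "self_attn_1_type": "torch_sdpa",
    "cross_attn_*_type": "torch_sdpa",
}


def find_unsafe_settings(cfg: dict) -> list[str]:
    """Return one message per key whose value differs from the SM120-safe one.

    Only keys present in `cfg` are checked; missing keys are not reported.
    """
    problems = []
    for key, value in cfg.items():
        for pattern, required in REQUIRED.items():
            if fnmatch(key, pattern) and value != required:
                problems.append(f"{key}={value!r} (expected {required!r})")
    return problems


# Hypothetical config fragment with two deliberately wrong values.
example_cfg = {
    "modulate_type": "triton",
    "rope_type": "torch",
    "self_attn_1_type": "torch_sdpa",
    "cross_attn_1_type": "flash_attn3",
}
problems = find_unsafe_settings(example_cfg)
```

A check like this only catches config drift; it does not replace running the component tests on the target GPU.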

If you add a new quant scheme or a new model_cls, create its own JSON under `configs/lightx2v/` mirroring these choices, and exercise it end-to-end via a new `tests/component/test_NN_*.py` before wiring it into the default config.

## HF download layout

`Wan-AI/Wan2.2-I2V-A14B` ships ~28 GB of bf16 DIT shards we replace with the quantised `dit_repo`. `BASE_REPO_IGNORE_PATTERNS` in [wan22.py](wan22.py) excludes them but **keeps** `high_noise_model/*.json` and `low_noise_model/*.json` — `set_config` parses architecture params (`dim`, etc.) from those. Don't broaden the ignore pattern without checking.
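
The effect of the ignore patterns can be previewed offline. This sketch assumes glob-style (`fnmatch`) matching for `ignore_patterns`; the two patterns are quoted from `BASE_REPO_IGNORE_PATTERNS`, while `is_ignored` and the candidate file names are hypothetical.

```python
from fnmatch import fnmatch

# The two DIT-shard patterns quoted from BASE_REPO_IGNORE_PATTERNS.
IGNORE = [
    "high_noise_model/*.safetensors",
    "low_noise_model/*.safetensors",
]


def is_ignored(path: str, patterns: list[str]) -> bool:
    """True when any glob pattern matches the repo-relative path."""
    return any(fnmatch(path, pattern) for pattern in patterns)


# Hypothetical repo-relative file names, for illustration only.
candidates = [
    "high_noise_model/config.json",
    "high_noise_model/shard-00001.safetensors",
    "low_noise_model/config.json",
]
kept = [path for path in candidates if not is_ignored(path, IGNORE)]

# Broadening a pattern to the whole directory would also drop the JSON
# metadata that set_config needs.
too_broad = ["high_noise_model/*"]
json_lost = is_ignored("high_noise_model/config.json", too_broad)
```

Running a dry check like this before editing the pattern list makes it obvious when a change would start excluding the architecture JSON.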

Supported quant schemes live in `wan22_dit_quant_scheme`:

- `fp8-sgl` — `lightx2v/Wan2.2-Distill-Models`, two `.safetensors` files
- `gguf-Q4_K_M`, `gguf-Q8_0`, … — `QuantStack/Wan2.2-I2V-A14B-GGUF`, layout `HighNoise/…` and `LowNoise/…`

Filenames are templated at the top of `wan22.py`; update those if the upstream repos rename files.
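
As a rough illustration of how a scheme string maps to file names, the sketch below uses the two QuantStack-layout template strings that appear in the `wan22.py` part of this diff. The helper name `gguf_filenames` is hypothetical and not part of the wrapper.

```python
# Template strings as they appear in wan22.py (QuantStack layout).
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"

GGUF_PREFIX = "gguf-"


def gguf_filenames(dit_quant_scheme: str) -> tuple[str, str]:
    """Map e.g. 'gguf-Q4_K_M' to the (high_noise, low_noise) file names."""
    if not dit_quant_scheme.startswith(GGUF_PREFIX):
        raise ValueError(f"not a GGUF scheme: {dit_quant_scheme!r}")
    # Strip only the leading prefix, e.g. "gguf-Q4_K_M" -> "Q4_K_M".
    quant = dit_quant_scheme[len(GGUF_PREFIX):]
    return (
        GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant),
        GGUF_LOW_NOISE_TEMPLATE.format(quant=quant),
    )


high_file, low_file = gguf_filenames("gguf-Q4_K_M")
```

If an upstream repo renames its files, only the template constants should need to change.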

## Testing

- `tests/component/test_02_wan22_loras.py` — full pipeline load + LoRA apply
- `tests/component/test_09_gguf_generate.py` — GGUF end-to-end I2V
- `tests/component/test_10_t5_encode.py` — T5 encoder dtype path
- `tests/component/test_11_image_encode.py` — image → VAE latent
- `tests/component/test_12_dit_single_step.py` — one DIT step per expert
- `tests/component/test_13_vae_decode.py` — VAE decode → RGB

When diagnosing a Blackwell regression, run 10 → 11 → 12 → 13 in that order; the failure localises to the first failing stage.
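
The staged run can be scripted so it stops at the first failing stage. This is a sketch under assumptions: the module names are derived from the test file paths and the `python -m tests.component.<name>` invocation style, and `first_failing_stage` / `run_module` are hypothetical helpers, not existing repo tooling.

```python
import subprocess
import sys
from typing import Callable, Optional, Sequence

# Stages 10 -> 13, in the order recommended above.
STAGES = [
    "tests.component.test_10_t5_encode",
    "tests.component.test_11_image_encode",
    "tests.component.test_12_dit_single_step",
    "tests.component.test_13_vae_decode",
]


def run_module(module: str) -> bool:
    """Run one stage as `python -m <module>`; True when it exits with code 0."""
    return subprocess.run([sys.executable, "-m", module]).returncode == 0


def first_failing_stage(
    stages: Sequence[str] = STAGES,
    run: Callable[[str], bool] = run_module,
) -> Optional[str]:
    """Run stages in order and return the first that fails, or None."""
    for module in stages:
        if not run(module):
            return module
    return None


# Dry run with a fake runner that pretends the DIT stage fails.
simulated = first_failing_stage(run=lambda module: "test_12" not in module)
```

Inside the container, calling `first_failing_stage()` with the default runner would execute the real stages; the injectable `run` argument exists only so the ordering logic can be exercised without a GPU.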

## LoRAs

`switch_lora(path, strength)` applies; `switch_lora("", 0.0)` removes. `load_loras`/`unload_loras` in this wrapper iterate over `LoRASpec`s from config and call `switch_lora` per `target` sub-model (`high_noise`, `low_noise`, or `both`). Wrong target = silently wrong output.
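
The per-target fan-out can be pictured as follows. This is a sketch, not the wrapper's code: `LoRASpecSketch` is a stand-in for the real `LoRASpec` (only `path`, `weight`, and `target` are modelled), and `expand_targets` is a hypothetical helper that raises on an unknown target instead of guessing.

```python
from dataclasses import dataclass

SUB_MODELS = ("high_noise", "low_noise")


@dataclass
class LoRASpecSketch:
    """Stand-in for LoRASpec; field names follow the config keys."""
    path: str
    weight: float = 1.0
    target: str = "both"


def expand_targets(spec: LoRASpecSketch) -> list[str]:
    """Return the sub-models a spec applies to; reject unknown targets."""
    if spec.target == "both":
        return list(SUB_MODELS)
    if spec.target in SUB_MODELS:
        return [spec.target]
    raise ValueError(f"unknown LoRA target {spec.target!r}")


# Hypothetical LoRA files, for illustration only.
specs = [
    LoRASpecSketch(path="style.safetensors", weight=0.8, target="both"),
    LoRASpecSketch(path="motion.safetensors", weight=1.0, target="high_noise"),
]

# One (sub_model, path, strength) entry per application that would be made.
plan = [
    (sub_model, spec.path, spec.weight)
    for spec in specs
    for sub_model in expand_targets(spec)
]
```

Failing loudly on an unrecognised target is a deliberate choice here, since a LoRA routed to the wrong sub-model produces wrong output without any error.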
|
||||
@@ -32,8 +32,13 @@ class MuseTalkEngine:
|
||||
def _load_impl(model_path: str):
|
||||
"""Load the MuseTalk inference implementation.
|
||||
|
||||
If none of the known entry points work the error message points at
|
||||
this file so you know where to fix it.
|
||||
Upstream MuseTalk has no library-style entry point — it's a bundle
|
||||
of training/inference CLI scripts. The bhetherman/MuseTalk fork at
|
||||
``third_party/MuseTalk`` adds package metadata but the low-level
|
||||
API is still the raw ``musetalk.utils.*`` and ``musetalk.models.*``
|
||||
modules. We import them here to verify the install succeeded; the
|
||||
actual pipeline (VAE, UNet, Whisper, face detection, blending)
|
||||
is wired up inside ``MuseTalkEngine.lip_sync``.
|
||||
"""
|
||||
resolved = model_path
|
||||
if not os.path.isdir(model_path) and "/" in model_path:
|
||||
@@ -43,28 +48,19 @@ class MuseTalkEngine:
|
||||
except Exception as e: # pragma: no cover
|
||||
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
|
||||
|
||||
# Try upstream MuseTalk repo layout.
|
||||
try:
|
||||
from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found]
|
||||
return MuseTalkInference(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found]
|
||||
return MuseTalkInfer(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from musetalk import Inference # type: ignore[import-not-found]
|
||||
return Inference(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
from musetalk.utils.utils import load_all_model # type: ignore[import-not-found] # noqa: F401
|
||||
from musetalk.utils.audio_processor import AudioProcessor # type: ignore[import-not-found] # noqa: F401
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MuseTalk Python package is not importable. "
|
||||
"Check that third_party/MuseTalk was installed via "
|
||||
"`pip install /opt/MuseTalk` in the Dockerfile."
|
||||
) from e
|
||||
|
||||
raise RuntimeError(
|
||||
"MuseTalk is installed but no known Python entry point was found. "
|
||||
"Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
|
||||
"to match the installed MuseTalk version."
|
||||
)
|
||||
# Return the resolved weight path; lip_sync loads models lazily on
|
||||
# first call so import-time failures don't block voice-only startup.
|
||||
return {"model_path": resolved, "loaded": False}
|
||||
|
||||
# --- Inference ---------------------------------------------------------
|
||||
|
||||
@@ -98,31 +94,22 @@ class MuseTalkEngine:
|
||||
if target_t > 0 and len(frames) != target_t:
|
||||
frames = _fit_frames_to_length(frames, target_t)
|
||||
|
||||
# The real MuseTalk call signature varies. Most common is a method
|
||||
# like ``run(frames, audio, sr, fps)`` or ``infer(...)``.
|
||||
for method_name in ("run", "infer", "lip_sync", "__call__"):
|
||||
method = getattr(self._infer, method_name, None)
|
||||
if method is None:
|
||||
continue
|
||||
try:
|
||||
result = method(
|
||||
frames=frames,
|
||||
audio=audio,
|
||||
sample_rate=sample_rate,
|
||||
fps=fps,
|
||||
)
|
||||
return _ensure_uint8_rgb(result)
|
||||
except TypeError:
|
||||
# Try positional
|
||||
try:
|
||||
result = method(frames, audio, sample_rate, fps)
|
||||
return _ensure_uint8_rgb(result)
|
||||
except TypeError:
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
"MuseTalk wrapper could not find a working inference method. "
|
||||
"Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
|
||||
# MuseTalk's real inference path (see third_party/MuseTalk/scripts/
|
||||
# realtime_inference.py::Avatar.inference) needs:
|
||||
# - mmpose + mmcv + mmengine (dwpose keypoint detection)
|
||||
# - face_alignment (bbox)
|
||||
# - MuseTalk UNet + VAE weights (TMElyralab/MuseTalk HF repo)
|
||||
# - Whisper encoder (openai/whisper-tiny)
|
||||
# - face_parsing weights
|
||||
# Plus its preprocessing module has import-time side effects that
|
||||
# load dwpose weights from a CWD-relative path. Turn the full
|
||||
# pipeline on by extending this method once those deps are
|
||||
# installed and weights are resolved — until then, callers should
|
||||
# keep ``config.video.musetalk.enabled: false`` and VideoEngine
|
||||
# will skip the lip-sync pass.
|
||||
raise NotImplementedError(
|
||||
"MuseTalk lip-sync pipeline is not wired up yet. "
|
||||
"Set config.video.musetalk.enabled=false to bypass."
|
||||
)
|
||||
|
||||
|
||||
|
||||
+145
-187
@@ -1,4 +1,4 @@
|
||||
"""Wan2.2-Lightning image-to-video wrapper via LightX2V.
|
||||
"""Wan2.2-TI2V-5B-Turbo (dense) image-to-video wrapper via LightX2V.
|
||||
|
||||
This wrapper targets LightX2V's actual Python entry points (verified against
|
||||
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
|
||||
@@ -22,15 +22,12 @@ Model weights are loaded once at construction and held resident across turns
|
||||
so reflective mode doesn't re-pay the load cost each reply.
|
||||
|
||||
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
|
||||
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards under high_noise_model/
|
||||
and low_noise_model/ are SKIPPED via
|
||||
ignore_patterns — we replace them with
|
||||
quantised checkpoints from dit_repo.
|
||||
- dit_repo (configurable) — quantised DIT checkpoints. Supported
|
||||
formats:
|
||||
* fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
|
||||
* GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)
|
||||
- Wan-AI/Wan2.2-TI2V-5B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards are SKIPPED via
|
||||
ignore_patterns — replaced by the GGUF
|
||||
checkpoint from dit_repo.
|
||||
- dit_repo (configurable) — single dense GGUF DIT checkpoint, e.g.
|
||||
hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -49,27 +46,20 @@ if TYPE_CHECKING:
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# --- fp8 distill filenames --------------------------------------------------
|
||||
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
|
||||
# --- GGUF filenames (QuantStack layout: HighNoise/<name>.gguf) ---------------
|
||||
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
|
||||
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"
|
||||
# --- GGUF filename for the dense 5B Turbo repo ------------------------------
|
||||
# hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF ships flat: Wan2_2-TI2V-5B-Turbo-{quant}.gguf
|
||||
GGUF_TURBO_5B_FILE = "Wan2_2-TI2V-5B-Turbo-{quant}.gguf"
|
||||
|
||||
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
|
||||
T5_FP8_REPO = "lightx2v/Encoders"
|
||||
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
|
||||
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
|
||||
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the
|
||||
# quantised files from dit_repo replace the DIT weights entirely. We must
|
||||
# keep the config.json / index.json metadata under high_noise_model/ and
|
||||
# low_noise_model/ (LightX2V's set_config reads architecture params like
|
||||
# ``dim`` from them) and the tokenizer files under google/.
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards alongside the T5/VAE/
|
||||
# tokenizer support files. We only need the latter — the GGUF from dit_repo
|
||||
# replaces the DIT weights entirely. Keep config.json / tokenizer files.
|
||||
BASE_REPO_IGNORE_PATTERNS = [
|
||||
"high_noise_model/*.safetensors",
|
||||
"low_noise_model/*.safetensors",
|
||||
"*.pt",
|
||||
"diffusion_pytorch_model*.safetensors",
|
||||
"assets/*",
|
||||
"examples/*",
|
||||
"nohup.out",
|
||||
@@ -77,6 +67,40 @@ BASE_REPO_IGNORE_PATTERNS = [
|
||||
]
|
||||
|
||||
|
||||
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
|
||||
"""Recursively find fp32 tensors reachable from ``obj`` and cast to fp16.
|
||||
|
||||
The dense ``wan2.2`` DIT isn't a standard ``nn.Module`` — some fp32
|
||||
tensors (conv3d bias etc.) live outside ``pre_weight``/``post_weight``
|
||||
and are missed by the structured sweep. This generic traversal catches
|
||||
them. Bounded depth + visited-set to avoid cycles.
|
||||
"""
|
||||
import torch
|
||||
if visited is None:
|
||||
visited = set()
|
||||
obj_id = id(obj)
|
||||
if obj_id in visited or depth > 6:
|
||||
return 0
|
||||
visited.add(obj_id)
|
||||
n = 0
|
||||
for attr_name in dir(obj):
|
||||
if attr_name.startswith("__"):
|
||||
continue
|
||||
try:
|
||||
val = getattr(obj, attr_name)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
|
||||
try:
|
||||
setattr(obj, attr_name, val.to(torch.float16))
|
||||
n += 1
|
||||
except Exception:
|
||||
pass
|
||||
elif hasattr(val, "__dict__") and not callable(val):
|
||||
n += _cast_all_fp32_tensors(val, visited, depth + 1)
|
||||
return n
|
||||
|
||||
|
||||
def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
|
||||
|
||||
@@ -132,13 +156,11 @@ def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||
|
||||
|
||||
class Wan22Pipeline:
|
||||
"""Wrapper around LightX2V's Wan2.2 MoE distill runner.
|
||||
"""Wrapper around LightX2V's dense Wan2.2-TI2V-5B-Turbo runner.
|
||||
|
||||
Supports two DIT quantisation formats:
|
||||
* **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
|
||||
``lightx2v/Wan2.2-Distill-Models``)
|
||||
* **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
|
||||
from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)
|
||||
The 5B Turbo repo ships a single dense DIT checkpoint (not MoE) as GGUF.
|
||||
``dit_quant_scheme`` must be a GGUF variant (``gguf-Q8_0`` default,
|
||||
``gguf-Q4_K_M`` for lower VRAM); no fp8 5B Turbo weights exist.
|
||||
|
||||
Constructor downloads (if needed) both HF repos, writes a runtime JSON
|
||||
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
|
||||
@@ -150,11 +172,11 @@ class Wan22Pipeline:
|
||||
base_repo: str,
|
||||
dit_repo: str,
|
||||
config_json: str,
|
||||
model_cls: str = "wan2.2_moe_distill",
|
||||
model_cls: str = "wan2.2",
|
||||
resolution: int = 480,
|
||||
fps: int = 16,
|
||||
dit_quant_scheme: str = "fp8-sgl",
|
||||
t5_quantized: bool = False,
|
||||
dit_quant_scheme: str = "gguf-Q8_0",
|
||||
t5_quantized: bool = True,
|
||||
):
|
||||
self.base_repo = base_repo
|
||||
self.dit_repo = dit_repo
|
||||
@@ -167,10 +189,15 @@ class Wan22Pipeline:
|
||||
self._applied_loras: list[LoRASpec] = []
|
||||
|
||||
self._is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||
if not self._is_gguf:
|
||||
raise ValueError(
|
||||
f"dit_quant_scheme must be a GGUF variant for dense 5B Turbo "
|
||||
f"(got {dit_quant_scheme!r}); no fp8 5B Turbo weights exist."
|
||||
)
|
||||
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpt.
|
||||
self._model_root = self._ensure_base_repo(base_repo)
|
||||
self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
|
||||
self._dit_ckpt = self._ensure_dit_checkpoint(
|
||||
dit_repo, dit_quant_scheme,
|
||||
)
|
||||
self._t5_fp8_ckpt = (
|
||||
@@ -306,52 +333,53 @@ class Wan22Pipeline:
|
||||
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
|
||||
|
||||
def _patch_dit_fp32_weights_for_gguf(self) -> None:
|
||||
"""Cast leftover fp32 DIT pre/post weights to fp16.
|
||||
"""Cast leftover fp32 DIT weights to fp16 (dense model).
|
||||
|
||||
GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
|
||||
of non-quantised weights (notably ``patch_embedding.pin_weight``) end
|
||||
up loaded as fp32. That breaks the first conv in the DIT forward pass
|
||||
(fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
|
||||
runs uniformly in fp16.
|
||||
GGUF dequantises the transformer blocks to fp16, but a handful of
|
||||
non-quantised weights (notably ``patch_embedding.pin_weight``) end up
|
||||
loaded as fp32. That breaks the first conv in the DIT forward pass
|
||||
(fp16 input vs fp32 weight). Dense ``wan2.2`` exposes the model
|
||||
directly at ``runner.model`` (no MoE wrapper). After the structured
|
||||
pre/post weight sweep, we also run a recursive traversal to catch
|
||||
fp32 conv3d biases etc. that live outside pre/post_weight.
|
||||
"""
|
||||
import torch
|
||||
|
||||
runner = self._runner
|
||||
models = getattr(runner.model, "model", None)
|
||||
if models is None:
|
||||
return
|
||||
if not isinstance(models, (list, tuple)):
|
||||
models = [models]
|
||||
n_struct = self._cast_fp32_dit_weights_in_model(runner.model)
|
||||
n_extra = _cast_all_fp32_tensors(runner.model)
|
||||
log.info(
|
||||
"Cast %d (structured) + %d (recursive) fp32 DIT tensors to fp16 for GGUF pipeline.",
|
||||
n_struct, n_extra,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _cast_fp32_dit_weights_in_model(m) -> int:
|
||||
import torch
|
||||
n_cast = 0
|
||||
for m in models:
|
||||
for weights_attr in ("pre_weight", "post_weight"):
|
||||
w = getattr(m, weights_attr, None)
|
||||
if w is None:
|
||||
for weights_attr in ("pre_weight", "post_weight"):
|
||||
w = getattr(m, weights_attr, None)
|
||||
if w is None:
|
||||
continue
|
||||
for sub_name in dir(w):
|
||||
if sub_name.startswith("_"):
|
||||
continue
|
||||
for sub_name in dir(w):
|
||||
if sub_name.startswith("_"):
|
||||
continue
|
||||
try:
|
||||
sub = getattr(w, sub_name)
|
||||
except Exception:
|
||||
continue
|
||||
if sub is None:
|
||||
continue
|
||||
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
|
||||
t = getattr(sub, t_name, None)
|
||||
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
|
||||
casted = t.to(torch.float16)
|
||||
# Preserve pinned-memory status on pin_* tensors so
|
||||
# move_attr_to_cuda's non-blocking H2D copy is safe.
|
||||
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
|
||||
try:
|
||||
casted = casted.pin_memory()
|
||||
except RuntimeError:
|
||||
pass
|
||||
setattr(sub, t_name, casted)
|
||||
n_cast += 1
|
||||
log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
|
||||
try:
|
||||
sub = getattr(w, sub_name)
|
||||
except Exception:
|
||||
continue
|
||||
if sub is None:
|
||||
continue
|
||||
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
|
||||
t = getattr(sub, t_name, None)
|
||||
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
|
||||
casted = t.to(torch.float16)
|
||||
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
|
||||
try:
|
||||
casted = casted.pin_memory()
|
||||
except RuntimeError:
|
||||
pass
|
||||
setattr(sub, t_name, casted)
|
||||
n_cast += 1
|
||||
return n_cast
|
||||
|
||||
# --- Weight provisioning -------------------------------------------------
|
||||
|
||||
@@ -374,46 +402,34 @@ class Wan22Pipeline:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_dit_checkpoints(
|
||||
def _ensure_dit_checkpoint(
|
||||
dit_repo: str,
|
||||
dit_quant_scheme: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Return (high_noise_path, low_noise_path) for the DIT pair.
|
||||
|
||||
Supports both fp8 safetensors and GGUF formats.
|
||||
"""
|
||||
) -> str:
|
||||
"""Return the local path to the single dense GGUF DIT checkpoint."""
|
||||
if not dit_repo:
|
||||
raise ValueError("dit_repo must be a HF repo id or local directory.")
|
||||
if not dit_quant_scheme.startswith("gguf-"):
|
||||
raise ValueError(
|
||||
f"Only GGUF quant schemes are supported for dense 5B Turbo "
|
||||
f"(got {dit_quant_scheme!r})."
|
||||
)
|
||||
|
||||
is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||
quant = dit_quant_scheme.replace("gguf-", "")
|
||||
filename = GGUF_TURBO_5B_FILE.format(quant=quant)
|
||||
|
||||
if is_gguf:
|
||||
# Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M"
|
||||
quant = dit_quant_scheme.replace("gguf-", "")
|
||||
high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
|
||||
low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
|
||||
else:
|
||||
high_file = FP8_HIGH_NOISE_FILE
|
||||
low_file = FP8_LOW_NOISE_FILE
|
||||
|
||||
# Local directory?
|
||||
if os.path.isdir(dit_repo):
|
||||
high = os.path.join(dit_repo, high_file)
|
||||
low = os.path.join(dit_repo, low_file)
|
||||
if not (os.path.isfile(high) and os.path.isfile(low)):
|
||||
path = os.path.join(dit_repo, filename)
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(
|
||||
f"DIT checkpoints not found in {dit_repo}: expected "
|
||||
f"{high_file} and {low_file}"
|
||||
f"DIT checkpoint not found in {dit_repo}: expected {filename}"
|
||||
)
|
||||
return high, low
|
||||
return path
|
||||
|
||||
# HuggingFace download.
|
||||
from huggingface_hub import hf_hub_download
|
||||
log.info("Downloading %s DIT checkpoints from %s ...",
|
||||
log.info("Downloading %s DIT checkpoint from %s ...",
|
||||
dit_quant_scheme, dit_repo)
|
||||
high = hf_hub_download(repo_id=dit_repo, filename=high_file)
|
||||
low = hf_hub_download(repo_id=dit_repo, filename=low_file)
|
||||
return high, low
|
||||
return hf_hub_download(repo_id=dit_repo, filename=filename)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_t5_fp8() -> str:
|
||||
@@ -431,8 +447,7 @@ class Wan22Pipeline:
|
||||
cfg = json.load(f)
|
||||
# Drop editorial comments before passing to LightX2V.
|
||||
cfg.pop("_comment", None)
|
||||
cfg["high_noise_quantized_ckpt"] = self._dit_high
|
||||
cfg["low_noise_quantized_ckpt"] = self._dit_low
|
||||
cfg["dit_quantized_ckpt"] = self._dit_ckpt
|
||||
cfg.setdefault("fps", self.fps)
|
||||
|
||||
# T5 fp8 quantization.
|
||||
@@ -496,103 +511,46 @@ class Wan22Pipeline:
|
||||
# --- LoRA --------------------------------------------------------------
|
||||
|
||||
def load_loras(self, specs: list["LoRASpec"]) -> None:
|
||||
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
|
||||
"""Apply LoRAs to the dense Wan2.2-TI2V-5B pipeline.
|
||||
|
||||
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
|
||||
to route the LoRA to the correct expert.
|
||||
|
||||
With ``lazy_load`` the DIT models are ``None`` at this point, so
|
||||
runtime ``switch_lora`` is impossible. Instead we inject
|
||||
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
|
||||
the LoRAs are applied when the models materialise on first inference.
|
||||
|
||||
Without ``lazy_load`` (models already resident) we call
|
||||
``switch_lora`` with explicit high/low keyword args.
|
||||
Dense has a single DIT (no MoE experts), so ``target`` must be
|
||||
``"both"``. GGUF DIT weights don't expose a ``lora_down`` buffer,
|
||||
so ``switch_lora`` would crash — we use the dynamic-apply path that
|
||||
merges LoRAs during GGUF dequant.
|
||||
"""
|
||||
if not specs:
|
||||
return
|
||||
|
||||
# Resolve every path up-front (may trigger HF download).
|
||||
resolved: list[tuple["LoRASpec", str]] = []
|
||||
for spec in specs:
|
||||
if spec.target != "both":
|
||||
raise ValueError(
|
||||
f"Dense 5B Turbo has a single DIT; LoRA target must be "
|
||||
f"'both' (got {spec.target!r})."
|
||||
)
|
||||
local_path = self._resolve_lora_path(spec.path)
|
||||
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
|
||||
spec.name or spec.path, spec.weight, spec.target,
|
||||
local_path)
|
||||
log.info(" LoRA %s → strength=%.2f (%s)",
|
||||
spec.name or spec.path, spec.weight, local_path)
|
||||
resolved.append((spec, local_path))
|
||||
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
# Build the lora_configs list that LightX2V's lazy-load path
|
||||
# reads inside MultiDistillModelStruct.infer().
|
||||
lora_cfgs = []
|
||||
for spec, local_path in resolved:
|
||||
# LightX2V expects name "high_noise_model" / "low_noise_model"
|
||||
cfg_name = {
|
||||
"high_noise": "high_noise_model",
|
||||
"low_noise": "low_noise_model",
|
||||
}.get(spec.target)
|
||||
if cfg_name is None:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
lora_cfgs.append({
|
||||
"name": cfg_name,
|
||||
"path": local_path,
|
||||
"strength": spec.weight,
|
||||
})
|
||||
self._runner.set_config({
|
||||
"lora_configs": lora_cfgs,
|
||||
"lora_dynamic_apply": True,
|
||||
})
|
||||
else:
|
||||
# Models are loaded — use runtime hot-swap.
|
||||
high_path = high_strength = None
|
||||
low_path = low_strength = None
|
||||
for spec, local_path in resolved:
|
||||
if spec.target == "high_noise":
|
||||
high_path, high_strength = local_path, spec.weight
|
||||
elif spec.target == "low_noise":
|
||||
low_path, low_strength = local_path, spec.weight
|
||||
else:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
|
||||
kwargs: dict = {}
|
||||
if high_path is not None:
|
||||
kwargs["high_lora_path"] = high_path
|
||||
kwargs["high_lora_strength"] = high_strength
|
||||
if low_path is not None:
|
||||
kwargs["low_lora_path"] = low_path
|
||||
kwargs["low_lora_strength"] = low_strength
|
||||
ok = self._runner.switch_lora(**kwargs)
|
||||
if not ok:
|
||||
raise RuntimeError(
|
||||
"runner.switch_lora returned False. Check that your "
|
||||
"LightX2V build supports runtime LoRA updates for "
|
||||
f"{self.model_cls}.")
|
||||
|
||||
lora_cfgs = [
|
||||
{"path": local_path, "strength": spec.weight}
|
||||
for spec, local_path in resolved
|
||||
]
|
||||
self._runner.set_config({
|
||||
"lora_configs": lora_cfgs,
|
||||
"lora_dynamic_apply": True,
|
||||
})
|
||||
self._applied_loras = list(specs)
|
||||
|
||||
def unload_loras(self) -> None:
|
||||
"""Remove all currently applied LoRAs."""
|
||||
if not self._applied_loras:
|
||||
return
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
self._runner.set_config({
|
||||
"lora_configs": None,
|
||||
"lora_dynamic_apply": False,
|
||||
})
|
||||
# If models were materialised, drop them so the next inference
|
||||
# recreates them without LoRAs.
|
||||
model_struct = getattr(self._runner, "model", None)
|
||||
if model_struct is not None and hasattr(model_struct, "model"):
|
||||
for i in range(len(model_struct.model)):
|
||||
model_struct.model[i] = None
|
||||
else:
|
||||
self._runner.switch_lora("", 0.0)
|
||||
self._runner.set_config({
|
||||
"lora_configs": None,
|
||||
"lora_dynamic_apply": False,
|
||||
})
|
||||
self._applied_loras = []
|
||||
|
||||
@staticmethod
|
||||
|
||||
+28
-8
@@ -9,25 +9,45 @@ python -m pytest tests/unit -v
 ```

 These exercise pure logic: config parsing, prompt derivation, LoRA spec
-parsing, frame-length fitting, library round-robin selection. They do not
-touch CUDA, Wan2.2, MuseTalk, or ffmpeg. Safe to run on Windows, outside
-Docker, without any models installed.
+parsing, frame-length fitting, library round-robin selection, the
+pipeline's video branch, and ffmpeg mux argument shaping. They do not
+touch CUDA, Wan2.2, MuseTalk, or a real ffmpeg binary. Safe to run on
+Windows, outside Docker, without any models installed.

+Current unit files:
+
+- `test_video_config.py` — `VideoConfig.from_dict` round-trip, LoRA target validation
+- `test_video_engine_logic.py` — prompt derivation, library cursor, frame fitting
+- `test_pipeline_video_branch.py` — pipeline takes the video path iff engine is ready
+- `test_musetalk_fit_frames.py` — frame-length adjustment to match audio duration
+- `test_muxer_ffmpeg.py` — ffmpeg command construction
+
 ## Component tests — slow, GPU-required, run inside Docker

-Each script in `tests/component/` exercises one subsystem end-to-end against
-the real models. They are ordered to match the implementation phases:
+Each script in `tests/component/` exercises one subsystem end-to-end
+against the real models. The numbered prefix reflects the implementation
+phase each script gates, and also serves as a reasonable run order when
+debugging a fresh environment:

 | Script | Phase | Tests |
 |---|---|---|
 | `test_01_video_skeleton.py` | 1 | VideoEngine loads, config gate respected |
 | `test_02_wan22_loras.py` | 2 | Wan2.2 pipeline loads, LoRA stack applies |
-| `test_03_idle_clip.py` | 3 | set_avatar → idle MP4, written to disk for eyeballing |
+| `test_03_idle_clip.py` | 3 | `set_avatar` → idle MP4, written to disk for eyeballing |
 | `test_04_library_prebake.py` | 4 | library mode pre-bakes N base clips |
 | `test_05_musetalk_lipsync.py` | 5 | MuseTalk lip-sync on library frames + ffmpeg mux |
 | `test_06_reflective.py` | 6 | reflective mode: fresh Wan2.2 per reply |
 | `test_07_endpoints.py` | 7 | HTTP endpoints return sane responses |
-| `test_08_lora_reload.py` | 8 | /api/reload-loras swaps LoRAs live |
+| `test_08_lora_reload.py` | 8 | `/api/reload-loras` swaps LoRAs live |
+| `test_09_gguf_generate.py` | 9 | GGUF-quantised DIT end-to-end I2V generation |
+| `test_10_t5_encode.py` | 10 | T5 encoder (optionally fp8-quantised) on CUDA |
+| `test_11_image_encode.py` | 11 | Avatar image → VAE latent path |
+| `test_12_dit_single_step.py` | 12 | Single DIT step on the loaded expert(s) |
+| `test_13_vae_decode.py` | 13 | VAE decode back to RGB frames |
+
+Tests 09-13 are focused on the GGUF + Blackwell (SM120) path and are how
+new quant schemes / attention backends get validated before wiring them
+into the full pipeline.

 Run one:

@@ -36,7 +56,7 @@ Run one:
 docker compose exec voice-chat python -m tests.component.test_03_idle_clip
 ```

-Run all (slow, ~20+ minutes on 5090):
+Run all (slow, ~20+ minutes on a 5090):

 ```
 docker compose exec voice-chat python -m tests.component.run_all
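The README hunk above says the numbered prefix of each component script doubles as a run order. The following is a minimal sketch of how a runner could derive that order from file names. The helper name `ordered_component_tests` is hypothetical, and the real `tests/component/run_all.py` is not shown in this diff and may work differently.

```python
import re

# Matches names such as "test_03_idle_clip.py" and captures the numeric prefix.
_PREFIX = re.compile(r"^test_(\d+)_.+\.py$")


def ordered_component_tests(filenames):
    """Return the component test scripts sorted by numeric prefix.

    Hypothetical helper: names without a numeric prefix (for example
    "_common.py" or "run_all.py") are ignored.
    """
    numbered = []
    for name in filenames:
        match = _PREFIX.match(name)
        if match:
            numbered.append((int(match.group(1)), name))
    return [name for _, name in sorted(numbered)]
```

Sorting on the parsed integer rather than on the raw string keeps the order stable even if a prefix is ever written without zero padding.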
@@ -1,26 +1,26 @@
-"""Phase 2 component test: Wan2.2 pipeline + LoRA stacking.
+"""Phase 2 component test: dense Wan2.2-TI2V-5B-Turbo pipeline + LoRA stacking.

 Verifies:
 - ``Wan22Pipeline`` loads successfully (exercises the real LightX2V
   set_config -> init_runner flow).
-- ``load_loras`` / ``unload_loras`` survive with the two user LoRAs at
-  ``/cache/loras/wan22-[HL]-e8.safetensors``.
+- ``load_loras`` / ``unload_loras`` survive with any user LoRAs at
+  ``/cache/loras/*.safetensors`` (target='both', dense single DIT).

-Supports both fp8 and GGUF DIT quantisation. Set the ``DIT_QUANT``
-environment variable to switch (default: ``fp8-sgl``).
+Supports any GGUF quant published in hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
+Set ``DIT_QUANT`` to switch (default: ``gguf-Q8_0``).

     DIT_QUANT=gguf-Q4_K_M docker compose exec voice-chat \
         python -m tests.component.test_02_wan22_loras

-Requires GPU and a first-run download of both HF repos (base support files
-~12 GB, DIT size depends on quant — fp8 ~30 GB, GGUF Q4_K_M ~19 GB).
+Requires GPU and a first-run download of the base repo + GGUF DIT.
 If LightX2V isn't installed the test is skipped.

-Run (default fp8):
+Run:
     docker compose exec voice-chat python -m tests.component.test_02_wan22_loras
 """
 from __future__ import annotations

+import glob
 import os
 import sys

@@ -28,19 +28,9 @@ from tests.component._common import get_logger

 log = get_logger("test_02")

-# --- Quant-dependent defaults ------------------------------------------------
-
-DIT_QUANT = os.environ.get("DIT_QUANT", "fp8-sgl")
-
-if DIT_QUANT.startswith("gguf-"):
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
-    DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
-else:
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
-    DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
-
-LORA_HIGH = "/cache/loras/wan22-H-e8.safetensors"
-LORA_LOW = "/cache/loras/wan22-L-e8.safetensors"
+DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
+CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
+DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"


 def run():
@@ -57,13 +47,14 @@ def run():
              "(quant=%s, dit_repo=%s)...", DIT_QUANT, DIT_REPO)
     try:
         pipe = Wan22Pipeline(
-            base_repo="Wan-AI/Wan2.2-I2V-A14B",
+            base_repo="Wan-AI/Wan2.2-TI2V-5B",
             dit_repo=DIT_REPO,
             config_json=CONFIG_JSON,
-            model_cls="wan2.2_moe_distill",
+            model_cls="wan2.2",
             resolution=480,
             fps=16,
             dit_quant_scheme=DIT_QUANT,
+            t5_quantized=True,
         )
     except Exception as e:
         log.error("FAIL: Wan22Pipeline construction raised: %s", e)
@@ -78,34 +69,27 @@ def run():
     pipe.load_loras([])
     log.info(" PASS")

-    if not (os.path.isfile(LORA_HIGH) and os.path.isfile(LORA_LOW)):
-        log.warning("SKIP: expected LoRA files not found at %s / %s",
-                    LORA_HIGH, LORA_LOW)
+    lora_files = sorted(glob.glob("/cache/loras/*.safetensors"))
+    if not lora_files:
+        log.warning("SKIP: no LoRA files found in /cache/loras/")
         log.info("ALL PASSED (partial — LoRA cases skipped)")
         return

-    log.info("[case 3] load_loras with the two MoE distill LoRAs")
+    lora_path = lora_files[0]
+    log.info("[case 3] load_loras with one 5B-compatible LoRA (%s)", lora_path)
     specs = [
         LoRASpec(
-            path=LORA_HIGH,
+            path=lora_path,
             weight=1.0,
-            target="high_noise",
-            name="wan22-H-e8",
-        ),
-        LoRASpec(
-            path=LORA_LOW,
-            weight=1.0,
-            target="low_noise",
-            name="wan22-L-e8",
+            target="both",
+            name=os.path.basename(lora_path),
         ),
     ]
     try:
         pipe.load_loras(specs)
     except Exception as e:
         log.error("FAIL: load_loras raised: %s", e)
-        log.error("Check: switch_lora support for wan2.2_moe_distill in the "
-                  "installed LightX2V build. If it errors there, pre-declare "
-                  "LoRAs in the config_json 'lora_configs' field instead.")
+        log.error("Check: LoRA checkpoint shape matches dense 5B DIT.")
         sys.exit(3)
     log.info(" PASS: LoRAs applied")

@@ -81,9 +81,9 @@ def run():
     body = {
         "loras": [
             {"path": "/cache/loras/a.safetensors", "weight": 0.8,
-             "target": "high_noise", "name": "test-a"},
+             "target": "both", "name": "test-a"},
             {"path": "/cache/loras/b.safetensors", "weight": 0.4,
-             "target": "low_noise"},
+             "target": "both"},
         ]
     }
     resp = client.post("/api/reload-loras", json=body)
@@ -32,28 +32,20 @@ def run():
     write_bytes("phase8_idle_noloras.mp4", idle_a)
     log.info("idle (no LoRAs) sha256=%s", hash_a[:16])

-    # Hot-reload with a distill LoRA
-    specs = [
-        LoRASpec(
-            path="lightx2v/Wan2.2-Distill-Loras:"
-                 "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step.safetensors",
-            weight=1.0,
-            target="high_noise",
-            name="distill-hi",
-        ),
-    ]
-    engine.load_loras(specs)
+    # Hot-reload flow: unload (no-op), reload empty list, verify clip still generates.
+    # There are no published 5B-Turbo-compatible LoRAs yet; when one exists,
+    # construct a LoRASpec(path=..., target="both", weight=1.0) and compare hashes.
+    engine.load_loras([])
     engine.set_avatar(avatar_path)
     idle_b = engine.get_idle_clip()
     assert idle_b is not None
     hash_b = hashlib.sha256(idle_b).hexdigest()
-    write_bytes("phase8_idle_withlora.mp4", idle_b)
-    log.info("idle (with LoRA) sha256=%s", hash_b[:16])
+    write_bytes("phase8_idle_reloaded.mp4", idle_b)
+    log.info("idle (post-reload) sha256=%s", hash_b[:16])

-    if hash_a != hash_b:
-        log.info("PASS: idle clip changed after LoRA reload")
-    else:
-        log.warning("clips identical — LoRA may not be applied; eyeball _out/*.mp4")
+    log.info("PASS: hot-reload round-trip completed "
+             "(hash match=%s — expected without a real LoRA applied).",
+             hash_a == hash_b)


 if __name__ == "__main__":
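The reload test above decides whether a LoRA had any effect by comparing SHA-256 digests of the idle clip before and after the reload. A minimal standalone sketch of that comparison follows; the helper names `clip_digest` and `clips_differ` are hypothetical and do not appear in the repo.

```python
import hashlib


def clip_digest(data: bytes) -> str:
    """Hex SHA-256 digest of an encoded clip (hypothetical helper)."""
    return hashlib.sha256(data).hexdigest()


def clips_differ(before: bytes, after: bytes) -> bool:
    """True when the two encoded clips are not byte-identical."""
    return clip_digest(before) != clip_digest(after)
```

Identical digests only show that the two MP4 byte streams match. Differing digests do not prove the LoRA was applied either, since the test has no way to rule out other sources of variation between runs, which is presumably why the old code also suggested eyeballing the output files.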
@@ -1,10 +1,10 @@
-"""Quick smoke test: generate a video clip with the GGUF pipeline.
+"""Quick smoke test: generate a video clip with the dense 5B Turbo GGUF pipeline.

 Calls Wan22Pipeline.generate_i2v directly (no MuseTalk, no VideoEngine)
 and writes the result to tests/component/_out/phase9_gguf.mp4.

 Run:
-    docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
+    docker compose exec -e DIT_QUANT=gguf-Q8_0 voice-chat \
         python -m tests.component.test_09_gguf_generate
 """
 from __future__ import annotations
@@ -16,14 +16,9 @@ from tests.component._common import ensure_sample_avatar, get_logger, write_byte

 log = get_logger("test_09")

-DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
-
-if DIT_QUANT.startswith("gguf-"):
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
-    DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
-else:
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
-    DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
+DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
+CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
+DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"


 def run():
@@ -38,10 +33,10 @@ def run():

     log.info("Building pipeline (quant=%s)...", DIT_QUANT)
     pipe = Wan22Pipeline(
-        base_repo="Wan-AI/Wan2.2-I2V-A14B",
+        base_repo="Wan-AI/Wan2.2-TI2V-5B",
         dit_repo=DIT_REPO,
         config_json=CONFIG_JSON,
-        model_cls="wan2.2_moe_distill",
+        model_cls="wan2.2",
         resolution=480,
         fps=16,
         dit_quant_scheme=DIT_QUANT,
@@ -17,14 +17,9 @@ from tests.component._common import get_logger

 log = get_logger("test_10")

-DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
-
-if DIT_QUANT.startswith("gguf-"):
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
-    DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
-else:
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
-    DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
+DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
+CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
+DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"


 def run():
@@ -36,10 +31,10 @@ def run():

     log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT)
     pipe = Wan22Pipeline(
-        base_repo="Wan-AI/Wan2.2-I2V-A14B",
+        base_repo="Wan-AI/Wan2.2-TI2V-5B",
         dit_repo=DIT_REPO,
         config_json=CONFIG_JSON,
-        model_cls="wan2.2_moe_distill",
+        model_cls="wan2.2",
         resolution=480,
         fps=16,
         dit_quant_scheme=DIT_QUANT,
@@ -22,14 +22,9 @@ from tests.component._common import ensure_sample_avatar, get_logger

 log = get_logger("test_11")

-DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
-
-if DIT_QUANT.startswith("gguf-"):
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
-    DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
-else:
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
-    DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
+DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
+CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
+DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"


 def run():
@@ -44,10 +39,10 @@ def run():

     log.info("Building pipeline (quant=%s)...", DIT_QUANT)
     pipe = Wan22Pipeline(
-        base_repo="Wan-AI/Wan2.2-I2V-A14B",
+        base_repo="Wan-AI/Wan2.2-TI2V-5B",
         dit_repo=DIT_REPO,
         config_json=CONFIG_JSON,
-        model_cls="wan2.2_moe_distill",
+        model_cls="wan2.2",
         resolution=480,
         fps=16,
         dit_quant_scheme=DIT_QUANT,
@@ -20,14 +20,9 @@ from tests.component._common import ensure_sample_avatar, get_logger

 log = get_logger("test_12")

-DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
-
-if DIT_QUANT.startswith("gguf-"):
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
-    DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
-else:
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
-    DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
+DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
+CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
+DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"


 def run():
@@ -42,10 +37,10 @@ def run():

     log.info("Building pipeline (quant=%s)...", DIT_QUANT)
     pipe = Wan22Pipeline(
-        base_repo="Wan-AI/Wan2.2-I2V-A14B",
+        base_repo="Wan-AI/Wan2.2-TI2V-5B",
         dit_repo=DIT_REPO,
         config_json=CONFIG_JSON,
-        model_cls="wan2.2_moe_distill",
+        model_cls="wan2.2",
         resolution=480,
         fps=16,
         dit_quant_scheme=DIT_QUANT,
@@ -19,14 +19,9 @@ from tests.component._common import ensure_sample_avatar, get_logger, write_byte

 log = get_logger("test_13")

-DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
-
-if DIT_QUANT.startswith("gguf-"):
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
-    DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
-else:
-    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
-    DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
+DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
+CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
+DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"


 def run():
@@ -41,10 +36,10 @@ def run():

     log.info("Building pipeline (quant=%s)...", DIT_QUANT)
     pipe = Wan22Pipeline(
-        base_repo="Wan-AI/Wan2.2-I2V-A14B",
+        base_repo="Wan-AI/Wan2.2-TI2V-5B",
         dit_repo=DIT_REPO,
         config_json=CONFIG_JSON,
-        model_cls="wan2.2_moe_distill",
+        model_cls="wan2.2",
         resolution=480,
         fps=16,
         dit_quant_scheme=DIT_QUANT,
@@ -64,32 +64,39 @@ def test_lora_parse_full():
         {
             "loras": [
                 {
-                    "path": "/tmp/hi.safetensors",
+                    "path": "/tmp/a.safetensors",
                     "weight": 0.7,
-                    "target": "high_noise",
-                    "name": "hi-noise-style",
+                    "target": "both",
+                    "name": "style-a",
                 },
                 {
-                    "path": "/tmp/lo.safetensors",
+                    "path": "/tmp/b.safetensors",
                     "weight": 0.4,
-                    "target": "low_noise",
-                    "name": "lo-noise-style",
+                    "target": "both",
+                    "name": "style-b",
                 },
             ]
         }
     )
     assert len(cfg.loras) == 2
-    assert cfg.loras[0].target == "high_noise"
-    assert cfg.loras[0].name == "hi-noise-style"
-    assert cfg.loras[1].target == "low_noise"
+    assert cfg.loras[0].target == "both"
+    assert cfg.loras[0].name == "style-a"
+    assert cfg.loras[1].target == "both"
     assert cfg.loras[1].weight == 0.4


-def test_lora_invalid_target_falls_back_to_both():
+def test_lora_legacy_moe_target_coerced_to_both():
+    """Legacy MoE configs with target='high_noise'/'low_noise' get coerced."""
     cfg = VideoConfig.from_dict(
-        {"loras": [{"path": "/tmp/x.safetensors", "target": "bogus"}]}
+        {
+            "loras": [
+                {"path": "/tmp/hi.safetensors", "target": "high_noise"},
+                {"path": "/tmp/lo.safetensors", "target": "low_noise"},
+                {"path": "/tmp/x.safetensors", "target": "bogus"},
+            ]
+        }
     )
-    assert cfg.loras[0].target == "both"
+    assert all(l.target == "both" for l in cfg.loras)


 def test_lora_entries_without_path_are_dropped():
@@ -107,8 +114,8 @@ def test_models_section_override():
                 "wan22_base_repo": "/local/weights/wan22",
                 "wan22_dit_repo": "/local/weights/wan22-dit",
                 "wan22_dit_quant_scheme": "gguf-Q4_K_M",
-                "wan22_config_json": "/local/cfg/fp8.json",
-                "wan22_model_cls": "wan2.2_moe",
+                "wan22_config_json": "/local/cfg/turbo.json",
+                "wan22_model_cls": "wan2.2",
                 "musetalk_path": "/local/weights/musetalk",
             }
         }
@@ -116,18 +123,16 @@ def test_models_section_override():
     assert cfg.wan22_base_repo == "/local/weights/wan22"
     assert cfg.wan22_dit_repo == "/local/weights/wan22-dit"
     assert cfg.wan22_dit_quant_scheme == "gguf-Q4_K_M"
-    assert cfg.wan22_config_json == "/local/cfg/fp8.json"
-    assert cfg.wan22_model_cls == "wan2.2_moe"
+    assert cfg.wan22_config_json == "/local/cfg/turbo.json"
+    assert cfg.wan22_model_cls == "wan2.2"
     assert cfg.musetalk_model_path == "/local/weights/musetalk"


-def test_models_section_backwards_compat_fp8_repo():
-    """Old config key wan22_fp8_repo still works via fallback."""
-    cfg = VideoConfig.from_dict(
-        {
-            "models": {
-                "wan22_fp8_repo": "/local/weights/wan22-fp8",
-            }
-        }
-    )
-    assert cfg.wan22_dit_repo == "/local/weights/wan22-fp8"
+def test_models_section_defaults_to_5b_turbo():
+    cfg = VideoConfig.from_dict({})
+    assert cfg.wan22_base_repo == "Wan-AI/Wan2.2-TI2V-5B"
+    assert cfg.wan22_dit_repo == "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
+    assert cfg.wan22_dit_quant_scheme == "gguf-Q8_0"
+    assert cfg.wan22_t5_quantized is True
+    assert cfg.wan22_model_cls == "wan2.2"
+    assert cfg.wan22_config_json == "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
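The unit tests above pin down the LoRA parsing behaviour after the move to the dense 5B model: every `target` value (legacy `high_noise` / `low_noise`, or anything unrecognised) ends up as `"both"`, and entries without a `path` are dropped. The sketch below illustrates one way to get that behaviour. It is not the repo's `VideoConfig.from_dict`: the names `LoRASpecSketch` and `parse_loras` are hypothetical, and the default weight of 1.0 is an assumption that the diff does not confirm.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class LoRASpecSketch:
    """Illustrative stand-in for the repo's LoRASpec (fields are assumed)."""
    path: str
    weight: float = 1.0          # assumed default, not shown in the diff
    target: str = "both"         # dense single DIT: only one valid target
    name: Optional[str] = None


def parse_loras(raw: List[Dict[str, Any]]) -> List[LoRASpecSketch]:
    """Parse raw config entries, coercing every target to "both"."""
    specs = []
    for entry in raw:
        path = entry.get("path")
        if not path:
            # Entries without a path cannot be loaded, so skip them.
            continue
        specs.append(
            LoRASpecSketch(
                path=path,
                weight=float(entry.get("weight", 1.0)),
                target="both",   # legacy MoE targets are ignored
                name=entry.get("name"),
            )
        )
    return specs
```

Coercing rather than rejecting legacy targets means an old MoE-era `config.yml` keeps loading, at the cost of silently ignoring the original high-noise / low-noise intent.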