Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 129df7d1fa | |||
| 9debc56137 | |||
| 56923ff424 | |||
| fcf0be38bc | |||
| 2818b41004 |
@@ -1,3 +1,6 @@
|
|||||||
.venv
|
.venv
|
||||||
.claude
|
.claude
|
||||||
__pycache__
|
__pycache__
|
||||||
|
tests/component/_out/
|
||||||
|
avatars/
|
||||||
|
loras/
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "third_party/MuseTalk"]
|
||||||
|
path = third_party/MuseTalk
|
||||||
|
url = https://git.hetherman.cloud/bhetherman/MuseTalk.git
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
# Agent guide — voice-chat
|
||||||
|
|
||||||
|
Orientation for AI agents making changes in this repo. Keep edits scoped; this pipeline has several sharp edges that a naive refactor will hit.
|
||||||
|
|
||||||
|
## What this project is
|
||||||
|
|
||||||
|
A local real-time voice (and optionally video-avatar) chat server. FastAPI + WebSocket on the backend, a small vanilla-JS UI on the frontend. All ML runs locally on a single NVIDIA GPU; there are no cloud API calls at runtime.
|
||||||
|
|
||||||
|
## Two-tier architecture
|
||||||
|
|
||||||
|
1. **Audio pipeline** (always on): mic PCM → [vad.py](server/vad.py) → [asr.py](server/asr.py) → [llm.py](server/llm.py) → [tts.py](server/tts.py) → PCM back to the browser. Orchestrated by [pipeline.py](server/pipeline.py) via `ConversationSession`. Supports barge-in (user speaks → cancel in-flight reply).
|
||||||
|
|
||||||
|
2. **Video pipeline** (optional, gated by `config.video.enabled`): per assistant turn, [video.py](server/video.py)'s `VideoEngine` calls [video_models/wan22.py](server/video_models/wan22.py) (LightX2V Wan2.2-I2V) to produce base frames, then [video_models/musetalk.py](server/video_models/musetalk.py) to lip-sync them to the TTS audio, then [video_models/muxer.py](server/video_models/muxer.py) to produce an MP4. The MP4 is sent over the same `/ws/chat` WebSocket as a `speaking_clip` message.
|
||||||
|
|
||||||
|
The audio path must keep working when `video.enabled` is false. Don't make video models load-bearing for the audio pipeline.
|
||||||
|
|
||||||
|
## Key files
|
||||||
|
|
||||||
|
- [server/main.py](server/main.py) — FastAPI app, WebSocket, video/avatar HTTP endpoints
|
||||||
|
- [server/models.py](server/models.py) — lifetime of all models; `ModelManager.video_engine` is `None` when video is disabled
|
||||||
|
- [server/pipeline.py](server/pipeline.py) — per-session audio pipeline + video branch
|
||||||
|
- [server/config.py](server/config.py) — parses [config.yml](config.yml)
|
||||||
|
- [server/video.py](server/video.py) — `VideoConfig`, `LoRASpec`, `VideoEngine` (library vs reflective modes)
|
||||||
|
- [server/video_models/wan22.py](server/video_models/wan22.py) — LightX2V wrapper; fp8 + GGUF loading; Blackwell patches
|
||||||
|
- [configs/lightx2v/](configs/lightx2v/) — LightX2V inference config templates; must match `wan22_dit_quant_scheme`
|
||||||
|
- [tests/unit/](tests/unit/) — GPU-free tests, runnable on Windows host
|
||||||
|
- [tests/component/](tests/component/) — end-to-end tests, must run inside the Docker container
|
||||||
|
|
||||||
|
## Conventions
|
||||||
|
|
||||||
|
- Config: single source of truth is [config.yml](config.yml) → `server/config.py` dataclasses. Don't read env vars for runtime behaviour; if you need a new knob, add it to the dataclass and document it in `config.yml`.
|
||||||
|
- Logging: `log = logging.getLogger(__name__)` at module top; log level is set once in `server/main.py`. INFO for lifecycle, DEBUG for per-chunk chatter.
|
||||||
|
- Async: WebSocket handlers and endpoints are async, but heavy model work is sync — wrap via `asyncio.to_thread(...)` at the call site (see `set_avatar` in `main.py`).
|
||||||
|
- Concurrency: `VideoEngine` serialises generation with `self._lock`. Don't call model methods without holding it from another thread.
|
||||||
|
- Tests: every non-trivial logic change in `server/video.py` or `server/pipeline.py` should have a corresponding `tests/unit/` test. GPU-dependent behaviour goes in `tests/component/`.
|
||||||
|
|
||||||
|
## Gotchas
|
||||||
|
|
||||||
|
- **GPU architecture.** This is tuned for RTX 5090 / Blackwell (SM120) with PyTorch 2.8 + Triton 3.4. Several upstream kernels (flashinfer, flash_attn3, sgl_kernel fp8 matmul, Triton-fused scale/shift) are broken or unavailable there. See [server/video_models/AGENT.md](server/video_models/AGENT.md) before touching the Wan2.2 wrapper.
|
||||||
|
- **First launch is slow.** Hugging Face downloads land in the `huggingface-cache` Docker volume; a cold run pulls >20 GB.
|
||||||
|
- **Wan-AI/Wan2.2-I2V-A14B** ships bf16 DIT shards we don't want — `BASE_REPO_IGNORE_PATTERNS` in [wan22.py](server/video_models/wan22.py) excludes them. Keep that list in sync if the repo layout changes.
|
||||||
|
- **LoRA targets matter.** Wan2.2 is a MoE (high_noise + low_noise sub-models). A LoRA with the wrong `target` loads silently and produces subtly wrong output.
|
||||||
|
- **Don't mix audio+video state.** The audio pipeline must not block on video generation; video is produced for a turn *after* the full reply audio is available, and sent as a separate message.
|
||||||
|
|
||||||
|
## When in doubt
|
||||||
|
|
||||||
|
- Run `python -m pytest tests/unit -v` — it's fast and catches most regressions.
|
||||||
|
- For GPU changes, run the lowest-numbered relevant component test first; they're ordered to isolate failure to a single stage.
|
||||||
|
- Check [memory](../../.claude/projects/c--Users-bheth-Documents-voice-chat/memory/) (auto-loaded) for prior-session findings that aren't in the code.
|
||||||
+40
@@ -4,6 +4,9 @@ ENV DEBIAN_FRONTEND=noninteractive
|
|||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
# HuggingFace model cache — mounted as a volume so models persist across runs
|
# HuggingFace model cache — mounted as a volume so models persist across runs
|
||||||
ENV HF_HOME=/cache/huggingface
|
ENV HF_HOME=/cache/huggingface
|
||||||
|
# LoRA directory — users drop .safetensors files here and reference them
|
||||||
|
# from config.yml::video.loras. Bind-mounted via docker-compose.
|
||||||
|
ENV LORA_DIR=/cache/loras
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
python3.11 \
|
python3.11 \
|
||||||
@@ -38,6 +41,43 @@ RUN python3.11 -m pip install --no-cache-dir -r requirements.txt
|
|||||||
# Pre-download the spacy model that kokoro needs at runtime
|
# Pre-download the spacy model that kokoro needs at runtime
|
||||||
RUN python3.11 -m spacy download en_core_web_sm
|
RUN python3.11 -m spacy download en_core_web_sm
|
||||||
|
|
||||||
|
# --- Optional: avatar video stack -------------------------------------------
|
||||||
|
# These are heavy installs; keep them after the core deps so rebuilds only
|
||||||
|
# redo this layer when ONLY the video stack changes. If you don't plan to
|
||||||
|
# use config.video.enabled=true, you can comment this block out to speed
|
||||||
|
# up builds and shrink the image.
|
||||||
|
#
|
||||||
|
# LightX2V (Wan2.2-Lightning inference framework) — installed from source
|
||||||
|
# since there is no stable PyPI release yet.
|
||||||
|
RUN python3.11 -m pip install --no-cache-dir \
|
||||||
|
"git+https://github.com/ModelTC/LightX2V.git" || \
|
||||||
|
echo "LightX2V install failed — config.video.enabled must stay false until fixed"
|
||||||
|
#
|
||||||
|
# sgl-kernel (fp8 T5 encoder acceleration). The PyPI wheel lacks SM120
|
||||||
|
# (Blackwell) CUTLASS kernels; use SGLang's cu128 wheel index instead.
|
||||||
|
# Our wan22.py patches fp8_scaled_mm → torch._scaled_mm at runtime for
|
||||||
|
# Blackwell GPUs, but the sgl_kernel package itself must still be present.
|
||||||
|
RUN python3.11 -m pip install --no-cache-dir --no-deps \
|
||||||
|
"sgl-kernel @ https://github.com/sgl-project/whl/releases/download/v0.3.14.post1/sgl_kernel-0.3.14.post1%2Bcu128-cp310-abi3-manylinux2014_x86_64.whl" || \
|
||||||
|
echo "sgl-kernel install failed — fp8 T5 will fall back to bf16"
|
||||||
|
#
|
||||||
|
# MuseTalk (audio-driven lip-sync) — installed from the bhetherman/MuseTalk
|
||||||
|
# fork checked in as a submodule at third_party/MuseTalk. The upstream repo
|
||||||
|
# has no setup.py / pyproject.toml; our fork adds them so `pip install .`
|
||||||
|
# just works. We deliberately do NOT install its requirements.txt (it pins
|
||||||
|
# numpy==1.23.5, transformers==4.39.2, tensorflow==2.12.0 which conflict
|
||||||
|
# with the rest of the stack) — instead we install its real runtime deps
|
||||||
|
# explicitly here.
|
||||||
|
COPY third_party/MuseTalk /opt/MuseTalk
|
||||||
|
RUN python3.11 -m pip install --no-cache-dir --no-deps /opt/MuseTalk || \
|
||||||
|
echo "MuseTalk install failed — config.video.musetalk.enabled must stay false until fixed"
|
||||||
|
RUN python3.11 -m pip install --no-cache-dir \
|
||||||
|
librosa einops omegaconf ffmpeg-python || \
|
||||||
|
echo "MuseTalk runtime deps install failed"
|
||||||
|
#
|
||||||
|
# LoRA directory (user drops .safetensors here; bind-mounted in compose).
|
||||||
|
RUN mkdir -p /cache/loras
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|||||||
@@ -1,21 +1,29 @@
|
|||||||
# Voice Chat
|
# Voice Chat
|
||||||
|
|
||||||
A real-time voice conversation app powered by local AI models. Speak into your mic and get spoken responses back — all running on your own GPU with no cloud APIs.
|
A real-time voice conversation app powered by local AI models. Speak into your mic and get spoken responses back — and optionally a lip-synced talking-head video of a chosen avatar — all running on your own GPU with no cloud APIs.
|
||||||
|
|
||||||
## Pipeline
|
## Pipeline
|
||||||
|
|
||||||
**Mic input** → **VAD** (Silero ONNX) → **ASR** (Qwen3-ASR-0.6B) → **LLM** (Qwen3.5-0.8B) → **TTS** (Kokoro) → **Speaker output**
|
**Mic input** → **VAD** (Silero ONNX) → **ASR** (Qwen3-ASR-0.6B) → **LLM** (Qwen3.5-0.8B) → **TTS** (Kokoro) → **Speaker output**
|
||||||
|
|
||||||
|
When the optional video stack is enabled, each assistant turn also produces an MP4 via:
|
||||||
|
|
||||||
|
**TTS audio + avatar image** → **Wan2.2-Lightning I2V** (LightX2V, fp8 or GGUF) → **MuseTalk lip-sync** → **ffmpeg mux** → **`speaking_clip` WebSocket message**
|
||||||
|
|
||||||
- **VAD** — Silero VAD via ONNX Runtime, detects speech/silence boundaries on CPU
|
- **VAD** — Silero VAD via ONNX Runtime, detects speech/silence boundaries on CPU
|
||||||
- **ASR** — Qwen3-ASR-0.6B, bfloat16 on CUDA
|
- **ASR** — Qwen3-ASR-0.6B, bfloat16 on CUDA
|
||||||
- **LLM** — Qwen3.5-0.8B, loaded via transformers
|
- **LLM** — Qwen3.5-0.8B (local) or any model served by LM Studio
|
||||||
- **TTS** — Kokoro, streams sentence-by-sentence audio at 24 kHz
|
- **TTS** — Kokoro, streams sentence-by-sentence audio at 24 kHz
|
||||||
- **Barge-in** — interrupt the assistant mid-response by speaking
|
- **Barge-in** — interrupt the assistant mid-response by speaking
|
||||||
|
- **Video (optional)** — Wan2.2 I2V 14B MoE with LightX2V distill LoRAs, fp8 or GGUF-quantised DIT, lip-synced by MuseTalk. Two modes:
|
||||||
|
- `library` — pre-bakes a small set of speaking base clips per avatar, picks round-robin per turn, lip-syncs on the fly
|
||||||
|
- `reflective` — generates a fresh Wan2.2 clip per turn from a prompt derived from the assistant's reply
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- NVIDIA GPU with CUDA 12.8 support
|
- NVIDIA GPU with CUDA 12.8 support (tested on RTX 5090 / SM120 Blackwell)
|
||||||
- Docker + Docker Compose with the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
|
- Docker + Docker Compose with the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
|
||||||
|
- ~24 GB VRAM recommended when video is enabled (fp8); ~16 GB with `gguf-Q4_K_M`
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
@@ -27,6 +35,17 @@ Then open [http://localhost:8000](http://localhost:8000) in your browser.
|
|||||||
|
|
||||||
Models are downloaded from Hugging Face on first launch and cached in a Docker volume (`huggingface-cache`) so they persist across rebuilds.
|
Models are downloaded from Hugging Face on first launch and cached in a Docker volume (`huggingface-cache`) so they persist across rebuilds.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
All runtime behaviour is driven by [config.yml](config.yml):
|
||||||
|
|
||||||
|
- `llm.backend` — `local` (in-process transformers) or `lmstudio` (talks to a local LM Studio server)
|
||||||
|
- `llm.system_prompt`, `llm.max_cache_tokens` — prompt and KV-cache limit per session
|
||||||
|
- `video.enabled` — master toggle for the avatar video stack. When `false`, no video models load and the app behaves exactly like the audio-only pipeline
|
||||||
|
- `video.mode` — `library` or `reflective`
|
||||||
|
- `video.models.wan22_dit_quant_scheme` — `gguf-Q8_0` (default) or `gguf-Q4_K_M` for lower VRAM; any GGUF level LightX2V supports (dense 5B Turbo is GGUF-only)
|
||||||
|
- `video.loras` — list of LoRA adapters applied to the dense Wan2.2-TI2V-5B DIT at load time. Each entry has `path`, `weight`, `target` (always `both` — the 5B DIT is not MoE; legacy `high_noise`/`low_noise` values are coerced), and optional `name`. User LoRAs are mounted from `./loras/` into the container at `/cache/loras/`
|
||||||
|
|
||||||
## Local Development (without Docker)
|
## Local Development (without Docker)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -45,17 +64,41 @@ python run.py
|
|||||||
|
|
||||||
The server starts on port 8000.
|
The server starts on port 8000.
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
- `GET /` — browser UI
|
||||||
|
- `WS /ws/chat` — bidirectional audio + control WebSocket
|
||||||
|
- `POST /api/set-voice` — switch the Kokoro voice
|
||||||
|
- `POST /api/set-avatar` *(video only)* — upload an avatar image; (re)generates idle + library clips
|
||||||
|
- `GET /api/idle-clip` *(video only)* — cached idle loop MP4
|
||||||
|
- `POST /api/set-video-mode` *(video only)* — switch between `off` / `library` / `reflective`
|
||||||
|
- `POST /api/reload-loras` *(video only)* — hot-swap the LoRA stack; regenerates the idle clip
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
|
|
||||||
```
|
```
|
||||||
server/
|
server/
|
||||||
main.py — FastAPI app, WebSocket endpoint
|
main.py — FastAPI app, WebSocket + video endpoints
|
||||||
models.py — Model loading and management
|
models.py — Model loading and management (audio + optional video)
|
||||||
pipeline.py — VAD -> ASR -> LLM -> TTS orchestration
|
pipeline.py — VAD -> ASR -> LLM -> TTS orchestration, video branch
|
||||||
|
config.py — config.yml parsing
|
||||||
vad.py — Silero VAD (ONNX) streaming wrapper
|
vad.py — Silero VAD (ONNX) streaming wrapper
|
||||||
asr.py — Speech recognition engine
|
asr.py — Speech recognition engine
|
||||||
llm.py — Language model engine
|
llm.py — Language model engine (local + LM Studio backends)
|
||||||
tts.py — Kokoro TTS engine
|
tts.py — Kokoro TTS engine
|
||||||
audio_utils.py — PCM/float32 conversion helpers
|
audio_utils.py — PCM/float32 conversion helpers
|
||||||
|
video.py — VideoEngine orchestrator + VideoConfig + LoRASpec
|
||||||
|
video_models/
|
||||||
|
wan22.py — LightX2V Wan2.2 I2V pipeline wrapper (fp8 + GGUF)
|
||||||
|
musetalk.py — MuseTalk lip-sync wrapper
|
||||||
|
muxer.py — ffmpeg helpers: frames -> MP4, frames+audio -> MP4
|
||||||
|
|
||||||
|
configs/
|
||||||
|
lightx2v/ — LightX2V inference config templates (fp8 + GGUF variants)
|
||||||
|
|
||||||
static/ — Browser UI (HTML/JS/CSS)
|
static/ — Browser UI (HTML/JS/CSS)
|
||||||
|
avatars/ — uploaded avatar images (gitignored)
|
||||||
|
loras/ — user-supplied Wan2.2 LoRAs, mounted into the container
|
||||||
|
reference_audio/ — Kokoro voice reference samples
|
||||||
|
tests/ — unit + component tests (see tests/README.md)
|
||||||
```
|
```
|
||||||
|
|||||||
+57
@@ -12,3 +12,60 @@ llm:
|
|||||||
lmstudio:
|
lmstudio:
|
||||||
url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker
|
url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker
|
||||||
model: "" # leave empty to use whatever model LM Studio has loaded
|
model: "" # leave empty to use whatever model LM Studio has loaded
|
||||||
|
|
||||||
|
# Avatar video generation (Wan2.2-TI2V-5B-Turbo GGUF via LightX2V + MuseTalk lip-sync)
|
||||||
|
video:
|
||||||
|
enabled: true # master toggle — when false, video models are not loaded
|
||||||
|
backend: lightx2v # only option for now
|
||||||
|
mode: reflective # "library" (pre-baked clips) | "reflective" (fresh per turn)
|
||||||
|
resolution: 480 # 480 or 720
|
||||||
|
fps: 16 # Wan2.2 native rate; MuseTalk resamples as needed
|
||||||
|
|
||||||
|
library:
|
||||||
|
base_clip_count: 4 # how many speaking base clips to pre-generate per avatar
|
||||||
|
base_clip_seconds: 6 # duration of each pre-baked clip
|
||||||
|
|
||||||
|
# MuseTalk audio-driven lip-sync. When disabled, Wan2.2 base frames are
|
||||||
|
# used as-is without a lip-sync pass — useful when MuseTalk isn't installed
|
||||||
|
# or while iterating on the base pipeline.
|
||||||
|
musetalk:
|
||||||
|
enabled: false # toggle lip-sync on/off
|
||||||
|
|
||||||
|
reflective:
|
||||||
|
clip_seconds: 5 # target length of each fresh Wan2.2 clip per turn
|
||||||
|
clip_prompt_template: >-
|
||||||
|
webcam view of a person speaking, {reply_hint},
|
||||||
|
casual gestures, natural lighting, soft focus background
|
||||||
|
prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint}
|
||||||
|
|
||||||
|
# Model sources for the video stack. T5/VAE/tokenizer come from the
|
||||||
|
# Wan-AI base repo. The single dense DIT comes from wan22_dit_repo as
|
||||||
|
# GGUF (Turbo 4-step distill). Both repos download on first run into
|
||||||
|
# HF_HOME=/cache/huggingface.
|
||||||
|
#
|
||||||
|
# Supported dit_quant_scheme values (dense 5B Turbo — GGUF only):
|
||||||
|
# gguf-Q8_0 — 8-bit, ~6 GB DIT, ~6.5 GB VRAM at load (default)
|
||||||
|
# gguf-Q4_K_M — 4-bit, ~3.5 GB DIT, lower VRAM for tight budgets
|
||||||
|
# (any gguf-<level> published in hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF)
|
||||||
|
models:
|
||||||
|
wan22_base_repo: Wan-AI/Wan2.2-TI2V-5B
|
||||||
|
wan22_dit_repo: hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF
|
||||||
|
wan22_dit_quant_scheme: gguf-Q8_0
|
||||||
|
wan22_t5_quantized: true
|
||||||
|
wan22_model_cls: wan2.2
|
||||||
|
wan22_config_json: /app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json
|
||||||
|
musetalk_path: TMElyralab/MuseTalk
|
||||||
|
|
||||||
|
# LoRAs applied to the dense 5B DIT at load time via LightX2V's
|
||||||
|
# lora_dynamic_apply path (merged during GGUF dequant). Dense has a
|
||||||
|
# single set of weights so `target` is always `both`.
|
||||||
|
#
|
||||||
|
# The old MoE-trained wan22-H-e8 / wan22-L-e8 LoRAs are NOT compatible
|
||||||
|
# with the 5B DIT and are disabled here. Future 5B-compatible LoRAs
|
||||||
|
# should follow the shape shown below.
|
||||||
|
loras: []
|
||||||
|
# loras:
|
||||||
|
# - path: /cache/loras/your-5b-lora.safetensors
|
||||||
|
# weight: 1.0
|
||||||
|
# target: both
|
||||||
|
# name: your-5b-lora
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
{
|
||||||
|
"_comment": "LightX2V config for Wan2.2-TI2V-5B-Turbo (dense, GGUF). Single DIT checkpoint (not MoE). dit_quantized_ckpt is filled in at runtime by Wan22Pipeline.",
|
||||||
|
|
||||||
|
"infer_steps": 4,
|
||||||
|
"target_video_length": 81,
|
||||||
|
"text_len": 512,
|
||||||
|
|
||||||
|
"resize_mode": "adaptive",
|
||||||
|
"resolution": "480p",
|
||||||
|
"target_height": 480,
|
||||||
|
"target_width": 480,
|
||||||
|
"fps": 16,
|
||||||
|
|
||||||
|
"vae_stride": [4, 16, 16],
|
||||||
|
"num_channels_latents": 48,
|
||||||
|
|
||||||
|
"self_attn_1_type": "torch_sdpa",
|
||||||
|
"cross_attn_1_type": "torch_sdpa",
|
||||||
|
"cross_attn_2_type": "torch_sdpa",
|
||||||
|
"modulate_type": "torch",
|
||||||
|
"rope_type": "torch",
|
||||||
|
|
||||||
|
"sample_guide_scale": 1.0,
|
||||||
|
"sample_shift": 5.0,
|
||||||
|
"enable_cfg": false,
|
||||||
|
|
||||||
|
"cpu_offload": false,
|
||||||
|
"offload_granularity": "model",
|
||||||
|
"t5_cpu_offload": true,
|
||||||
|
"vae_cpu_offload": false,
|
||||||
|
|
||||||
|
"use_image_encoder": false,
|
||||||
|
|
||||||
|
"denoising_step_list": [1000, 750, 500, 250],
|
||||||
|
|
||||||
|
"dit_quantized": true,
|
||||||
|
"dit_quant_scheme": "gguf-Q8_0",
|
||||||
|
"t5_quantized": true,
|
||||||
|
"t5_quant_scheme": "fp8-sgl"
|
||||||
|
}
|
||||||
+11
@@ -0,0 +1,11 @@
|
|||||||
|
"""Pytest configuration.
|
||||||
|
|
||||||
|
Ensures the project root is on ``sys.path`` so tests can import ``server.*``
|
||||||
|
without installing the project as a package.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
if _ROOT not in sys.path:
|
||||||
|
sys.path.insert(0, _ROOT)
|
||||||
@@ -6,10 +6,17 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
# Cache models on the host so they survive container rebuilds
|
# Cache models on the host so they survive container rebuilds
|
||||||
- huggingface-cache:/cache/huggingface
|
- huggingface-cache:/cache/huggingface
|
||||||
|
# LoRA adapters — drop .safetensors files into ./loras on the host,
|
||||||
|
# reference them from config.yml as /cache/loras/<file>.safetensors
|
||||||
|
- ./loras:/cache/loras
|
||||||
|
# Avatar images uploaded via the web UI persist between restarts
|
||||||
|
- ./avatars:/app/avatars
|
||||||
# Mount source so you can edit code/config without rebuilding the image
|
# Mount source so you can edit code/config without rebuilding the image
|
||||||
- ./config.yml:/app/config.yml:ro
|
- ./config.yml:/app/config.yml:ro
|
||||||
|
- ./configs:/app/configs:ro
|
||||||
- ./server:/app/server:ro
|
- ./server:/app/server:ro
|
||||||
- ./static:/app/static:ro
|
- ./static:/app/static:ro
|
||||||
|
- ./tests:/app/tests
|
||||||
- ./run.py:/app/run.py:ro
|
- ./run.py:/app/run.py:ro
|
||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
_out/
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
# LightX2V + Wan2.2-TI2V-5B-Turbo (GGUF) Experiment
|
||||||
|
|
||||||
|
Swap the 14B MoE distill for the dense 5B Turbo model, keeping the LightX2V backend.
|
||||||
|
Hypothesis: half the parameters → lower VRAM footprint (can coexist with the running
|
||||||
|
server) and faster per-step compute, with the Turbo 4-step distill preserving wall time.
|
||||||
|
|
||||||
|
## Config
|
||||||
|
|
||||||
|
- **Model**: `hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF` (Q8_0 by default — swap to Q4_K_M via env)
|
||||||
|
- **Base repo** (configs, T5, VAE): `Wan-AI/Wan2.2-TI2V-5B`
|
||||||
|
- **model_cls**: `wan2.2` (dense, single DIT — not MoE)
|
||||||
|
- **Steps**: 4 (Turbo distill)
|
||||||
|
- **Resolution**: 480×480, 81 frames @ 16 fps
|
||||||
|
|
||||||
|
## Key implementation details
|
||||||
|
|
||||||
|
- **Dense model (`wan2.2`)**: Uses single DIT checkpoint, not MoE — requires different dtype patching than the 14B pipeline
|
||||||
|
- **GGUF dequant → fp16**: Requires `DTYPE=FP16` and patches for T5 (bf16→fp16 wrapper), VAE (→fp16), and DIT pre/post weights (fp32→fp16)
|
||||||
|
- **Wan 2.2 VAE**: 48 latent channels with 16× spatial compression (vs 16 channels / 8× for Wan 2.1) — config must set `vae_stride: [4,16,16]` and `num_channels_latents: 48`
|
||||||
|
- **fp8 T5**: Uses `lightx2v/Encoders` fp8 checkpoint (~4.9 GB vs ~11.4 GB bf16)
|
||||||
|
- **Blackwell (SM120)**: Needs `_patch_fp8_scaled_mm_for_blackwell` to replace sgl_kernel's fp8 GEMM
|
||||||
|
|
||||||
|
## Why a separate container
|
||||||
|
|
||||||
|
Reuses the existing `voice-chat-voice-chat` image (LightX2V already installed) but runs
|
||||||
|
under its own compose profile so it doesn't interfere with the live server volumes or
|
||||||
|
startup. Shares the HF cache volume so model downloads are reused.
|
||||||
|
|
||||||
|
## Running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Ensure main image is built
|
||||||
|
docker compose build voice-chat
|
||||||
|
|
||||||
|
# Stage model (downloads base + Turbo Q8 GGUF, ~6 GB)
|
||||||
|
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \
|
||||||
|
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/setup_model.py
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \
|
||||||
|
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Reports peak VRAM and wall time for an 81-frame 480p clip.
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
| Metric | Value |
|
||||||
|
|--------|-------|
|
||||||
|
| Model load | ~43s |
|
||||||
|
| VRAM after load | 6.53 GB |
|
||||||
|
| T5 encode | ~1s |
|
||||||
|
| VAE encode | ~0.5s |
|
||||||
|
|
||||||
|
Awaiting full end-to-end benchmark completion for wall time and peak VRAM.
|
||||||
|
|
||||||
|
## Go / no-go criteria
|
||||||
|
|
||||||
|
- **Go**: < 45s per 81-frame clip AND peak VRAM < 12 GB (leaves ~20 GB for the server)
|
||||||
|
- **No-go**: keep the 14B MoE Q4_K_M pipeline
|
||||||
|
|
||||||
|
### Baselines
|
||||||
|
|
||||||
|
- **vLLM-Omni + fp16 Turbo-5B**: 1663s / 22.5 GB — decisive no-go
|
||||||
|
- **LightX2V + 14B MoE Q4_K_M**: ~30s/clip, ~14.5 GB VRAM (current production pipeline)
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
{
|
||||||
|
"_comment": "LightX2V config for Wan2.2-TI2V-5B-Turbo (dense, GGUF). Single DIT checkpoint (not MoE). dit_quantized_ckpt is filled in at runtime by setup_model.py / test_i2v.py.",
|
||||||
|
|
||||||
|
"infer_steps": 4,
|
||||||
|
"target_video_length": 81,
|
||||||
|
"text_len": 512,
|
||||||
|
|
||||||
|
"resize_mode": "adaptive",
|
||||||
|
"resolution": "480p",
|
||||||
|
"target_height": 480,
|
||||||
|
"target_width": 480,
|
||||||
|
"fps": 16,
|
||||||
|
|
||||||
|
"vae_stride": [4, 16, 16],
|
||||||
|
"num_channels_latents": 48,
|
||||||
|
|
||||||
|
"self_attn_1_type": "torch_sdpa",
|
||||||
|
"cross_attn_1_type": "torch_sdpa",
|
||||||
|
"cross_attn_2_type": "torch_sdpa",
|
||||||
|
"modulate_type": "torch",
|
||||||
|
"rope_type": "torch",
|
||||||
|
|
||||||
|
"sample_guide_scale": 1.0,
|
||||||
|
"sample_shift": 5.0,
|
||||||
|
"enable_cfg": false,
|
||||||
|
|
||||||
|
"cpu_offload": false,
|
||||||
|
"offload_granularity": "model",
|
||||||
|
"t5_cpu_offload": true,
|
||||||
|
"vae_cpu_offload": false,
|
||||||
|
|
||||||
|
"use_image_encoder": false,
|
||||||
|
|
||||||
|
"denoising_step_list": [1000, 750, 500, 250],
|
||||||
|
|
||||||
|
"dit_quantized": true,
|
||||||
|
"dit_quant_scheme": "gguf-Q8_0",
|
||||||
|
"t5_quantized": true,
|
||||||
|
"t5_quant_scheme": "fp8-sgl"
|
||||||
|
}
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
services:
|
||||||
|
lightx2v-5b:
|
||||||
|
image: voice-chat-voice-chat:latest
|
||||||
|
volumes:
|
||||||
|
- huggingface-cache:/cache/huggingface
|
||||||
|
- ../../:/app
|
||||||
|
working_dir: /app
|
||||||
|
environment:
|
||||||
|
- DTYPE=FP16
|
||||||
|
- HF_HOME=/cache/huggingface
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
count: 1
|
||||||
|
capabilities: [gpu]
|
||||||
|
shm_size: "8g"
|
||||||
|
ipc: host
|
||||||
|
profiles:
|
||||||
|
- experimental
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
huggingface-cache:
|
||||||
|
name: voice-chat_huggingface-cache
|
||||||
|
external: true
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
"""Stage Wan2.2-TI2V-5B-Turbo GGUF pipeline for LightX2V.
|
||||||
|
|
||||||
|
Downloads:
|
||||||
|
1. Base `Wan-AI/Wan2.2-TI2V-5B` snapshot (configs, T5, VAE — skip bf16 DIT shards).
|
||||||
|
2. Turbo Q8 GGUF DIT from `hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF`.
|
||||||
|
|
||||||
|
Quant file can be overridden via GGUF_FILE env (default Q8_0).
|
||||||
|
|
||||||
|
Idempotent: huggingface_hub handles caching.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
|
|
||||||
|
BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B"
|
||||||
|
GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||||
|
GGUF_FILE = os.environ.get(
|
||||||
|
"GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf"
|
||||||
|
)
|
||||||
|
T5_FP8_REPO = "lightx2v/Encoders"
|
||||||
|
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
print(f"\n=== 1/2 Snapshot base pipeline {BASE_REPO} ===", flush=True)
|
||||||
|
# The base repo ships bf16 DIT shards we don't need (we use the Turbo GGUF instead).
|
||||||
|
base_dir = snapshot_download(
|
||||||
|
repo_id=BASE_REPO,
|
||||||
|
ignore_patterns=[
|
||||||
|
"*.pt",
|
||||||
|
"diffusion_pytorch_model*.safetensors",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
print(f"Base pipeline at: {base_dir}")
|
||||||
|
|
||||||
|
print(f"\n=== 2/3 Download {GGUF_FILE} from {GGUF_REPO} ===", flush=True)
|
||||||
|
gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE)
|
||||||
|
print(f"GGUF DIT at: {gguf_path}")
|
||||||
|
|
||||||
|
print(f"\n=== 3/3 Download fp8 T5 from {T5_FP8_REPO} ===", flush=True)
|
||||||
|
t5_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
|
||||||
|
print(f"fp8 T5 at: {t5_path}")
|
||||||
|
|
||||||
|
print(f"\n{'=' * 50}")
|
||||||
|
print("Ready. Export to test_i2v.py via env:")
|
||||||
|
print(f" BASE_DIR={base_dir}")
|
||||||
|
print(f" DIT_GGUF={gguf_path}")
|
||||||
|
print(f" T5_FP8={t5_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,217 @@
|
|||||||
|
"""Benchmark Wan2.2-TI2V-5B-Turbo i2v under LightX2V.
|
||||||
|
|
||||||
|
Uses the dense `wan2.2` model_cls with a single Q8 GGUF DIT checkpoint.
|
||||||
|
Applies the same GGUF dtype patches as server/video_models/wan22.py (T5→bf16
|
||||||
|
wrapper, VAE→fp16, fp32 DIT pre/post weights→fp16).
|
||||||
|
|
||||||
|
Measures peak VRAM and wall time for an 81-frame 480p clip from the sample avatar.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \\
|
||||||
|
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
|
|
||||||
|
BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B"
|
||||||
|
GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||||
|
GGUF_FILE = os.environ.get("GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf")
|
||||||
|
T5_FP8_REPO = "lightx2v/Encoders"
|
||||||
|
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
|
||||||
|
|
||||||
|
CONFIG_JSON = Path(__file__).parent / "config.json"
|
||||||
|
SAMPLE_AVATAR = "/app/tests/component/sample_avatar.png"
|
||||||
|
OUTPUT = Path("/app/experimental/lightx2v_5b/_out/turbo_5b.mp4")
|
||||||
|
|
||||||
|
PROMPT = "a humanoid robot looking at the camera, shaking their head left and right, soft focus background"
|
||||||
|
SEED = 42
|
||||||
|
|
||||||
|
|
||||||
|
def human_gb(n: int) -> str:
|
||||||
|
return f"{n / (1024 ** 3):.2f} GB"
|
||||||
|
|
||||||
|
|
||||||
|
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
|
||||||
|
"""Recursively find and cast fp32 tensors to fp16 on any object tree."""
|
||||||
|
if visited is None:
|
||||||
|
visited = set()
|
||||||
|
obj_id = id(obj)
|
||||||
|
if obj_id in visited or depth > 6:
|
||||||
|
return 0
|
||||||
|
visited.add(obj_id)
|
||||||
|
n = 0
|
||||||
|
for attr_name in dir(obj):
|
||||||
|
if attr_name.startswith("__"):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
val = getattr(obj, attr_name)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
|
||||||
|
try:
|
||||||
|
setattr(obj, attr_name, val.to(torch.float16))
|
||||||
|
n += 1
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif hasattr(val, "__dict__") and not callable(val):
|
||||||
|
n += _cast_all_fp32_tensors(val, visited, depth + 1)
|
||||||
|
return n
|
||||||
|
|
||||||
|
|
||||||
|
def build_args(model_path: str, config_json: str) -> argparse.Namespace:
|
||||||
|
return argparse.Namespace(
|
||||||
|
seed=SEED,
|
||||||
|
model_cls="wan2.2",
|
||||||
|
task="i2v",
|
||||||
|
support_tasks=[],
|
||||||
|
model_path=model_path,
|
||||||
|
sf_model_path=None,
|
||||||
|
config_json=config_json,
|
||||||
|
use_prompt_enhancer=False,
|
||||||
|
prompt="",
|
||||||
|
negative_prompt="",
|
||||||
|
image_path="",
|
||||||
|
last_frame_path="",
|
||||||
|
audio_path="",
|
||||||
|
image_strength="1.0",
|
||||||
|
image_frame_idx="",
|
||||||
|
src_ref_images=None,
|
||||||
|
src_video=None,
|
||||||
|
src_mask=None,
|
||||||
|
src_pose_path=None,
|
||||||
|
src_face_path=None,
|
||||||
|
src_bg_path=None,
|
||||||
|
src_mask_path=None,
|
||||||
|
pose=None,
|
||||||
|
action_path=None,
|
||||||
|
action_ckpt=None,
|
||||||
|
save_result_path=None,
|
||||||
|
return_result_tensor=False,
|
||||||
|
target_shape=[],
|
||||||
|
target_video_length=81,
|
||||||
|
aspect_ratio="",
|
||||||
|
video_path=None,
|
||||||
|
sr_ratio=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
OUTPUT.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"\n=== Stage model ===", flush=True)
|
||||||
|
base_dir = snapshot_download(
|
||||||
|
repo_id=BASE_REPO,
|
||||||
|
ignore_patterns=["*.pt", "diffusion_pytorch_model*.safetensors"],
|
||||||
|
)
|
||||||
|
gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE)
|
||||||
|
t5_fp8_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
|
||||||
|
print(f" base_dir: {base_dir}")
|
||||||
|
print(f" dit_gguf: {gguf_path}")
|
||||||
|
print(f" t5_fp8: {t5_fp8_path}")
|
||||||
|
|
||||||
|
with open(CONFIG_JSON, "r", encoding="utf-8") as f:
|
||||||
|
cfg = json.load(f)
|
||||||
|
cfg.pop("_comment", None)
|
||||||
|
cfg["dit_quantized_ckpt"] = gguf_path
|
||||||
|
cfg["t5_quantized_ckpt"] = t5_fp8_path
|
||||||
|
|
||||||
|
tmp = tempfile.NamedTemporaryFile(
|
||||||
|
prefix="wan22_5b_", suffix=".json",
|
||||||
|
mode="w", delete=False, encoding="utf-8",
|
||||||
|
)
|
||||||
|
json.dump(cfg, tmp, indent=2)
|
||||||
|
tmp.close()
|
||||||
|
print(f" runtime config: {tmp.name}")
|
||||||
|
|
||||||
|
# Import LightX2V after env is set.
|
||||||
|
import sys
|
||||||
|
if "/app" not in sys.path:
|
||||||
|
sys.path.insert(0, "/app")
|
||||||
|
from server.video_models.wan22 import _patch_fp8_scaled_mm_for_blackwell
|
||||||
|
from lightx2v.infer import init_runner
|
||||||
|
from lightx2v.utils.input_info import (
|
||||||
|
init_empty_input_info,
|
||||||
|
update_input_info_from_dict,
|
||||||
|
)
|
||||||
|
from lightx2v.utils.set_config import set_config
|
||||||
|
|
||||||
|
_patch_fp8_scaled_mm_for_blackwell()
|
||||||
|
|
||||||
|
args = build_args(base_dir, tmp.name)
|
||||||
|
print(f"\n=== set_config + init_runner ===", flush=True)
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
t_load = time.perf_counter()
|
||||||
|
config = set_config(args)
|
||||||
|
runner = init_runner(config)
|
||||||
|
load_time = time.perf_counter() - t_load
|
||||||
|
vram_load = torch.cuda.memory_allocated()
|
||||||
|
print(f"[load] time={load_time:.1f}s vram={human_gb(vram_load)}", flush=True)
|
||||||
|
|
||||||
|
# GGUF dtype patches: flip DTYPE to FP16, patch T5/VAE/DIT to match.
|
||||||
|
os.environ["DTYPE"] = "FP16"
|
||||||
|
from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE
|
||||||
|
GET_DTYPE.cache_clear()
|
||||||
|
|
||||||
|
# Reuse the patch methods from Wan22Pipeline as standalone functions
|
||||||
|
# by constructing a minimal object with ._runner.
|
||||||
|
from server.video_models.wan22 import Wan22Pipeline
|
||||||
|
|
||||||
|
shim = Wan22Pipeline.__new__(Wan22Pipeline)
|
||||||
|
shim._runner = runner
|
||||||
|
Wan22Pipeline._patch_t5_dtype_for_gguf(shim)
|
||||||
|
Wan22Pipeline._patch_vae_dtype_for_gguf(shim)
|
||||||
|
Wan22Pipeline._patch_dit_fp32_weights_for_gguf(shim)
|
||||||
|
|
||||||
|
# Dense model: patch_dit_fp32 above expects MoE wrapper (runner.model.model).
|
||||||
|
# For dense wan2.2, the model is directly at runner.model — patch it explicitly.
|
||||||
|
n = Wan22Pipeline._cast_fp32_dit_weights_in_model(runner.model)
|
||||||
|
print(f"[patch] cast {n} fp32 DIT weights to fp16 (dense model)")
|
||||||
|
# Sweep all nested objects for any remaining fp32 tensors (conv3d bias, etc).
|
||||||
|
n_extra = _cast_all_fp32_tensors(runner.model)
|
||||||
|
print(f"[patch] cast {n_extra} additional fp32 tensors to fp16")
|
||||||
|
|
||||||
|
input_info = init_empty_input_info(args.task, args.support_tasks)
|
||||||
|
|
||||||
|
print(f"\n=== generate ===", flush=True)
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
|
||||||
|
out_path = tf.name
|
||||||
|
|
||||||
|
update_input_info_from_dict(
|
||||||
|
input_info,
|
||||||
|
{
|
||||||
|
"seed": SEED,
|
||||||
|
"prompt": PROMPT,
|
||||||
|
"negative_prompt": "",
|
||||||
|
"image_path": SAMPLE_AVATAR,
|
||||||
|
"save_result_path": out_path,
|
||||||
|
"target_video_length": 81,
|
||||||
|
"return_result_tensor": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
t_gen = time.perf_counter()
|
||||||
|
runner.run_pipeline(input_info)
|
||||||
|
gen_time = time.perf_counter() - t_gen
|
||||||
|
peak_vram = torch.cuda.max_memory_allocated()
|
||||||
|
print(f"\n[generate] wall_time={gen_time:.1f}s peak_vram={human_gb(peak_vram)}", flush=True)
|
||||||
|
|
||||||
|
if os.path.exists(out_path):
|
||||||
|
size = os.path.getsize(out_path)
|
||||||
|
OUTPUT.write_bytes(Path(out_path).read_bytes())
|
||||||
|
os.remove(out_path)
|
||||||
|
print(f"[output] wrote {OUTPUT} ({size} bytes)")
|
||||||
|
else:
|
||||||
|
print(f"[output] no file at {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -14,3 +14,15 @@ soundfile
|
|||||||
scipy
|
scipy
|
||||||
python-multipart
|
python-multipart
|
||||||
pyyaml
|
pyyaml
|
||||||
|
|
||||||
|
# --- Avatar video (optional, only used when config.video.enabled=true) ---
|
||||||
|
# Video frame I/O (used by video_models/wan22.py and the muxer).
|
||||||
|
imageio[ffmpeg]>=2.34
|
||||||
|
av>=12.0
|
||||||
|
pyzmq>=25.0
|
||||||
|
gguf>=0.6.0
|
||||||
|
# sgl-kernel: installed from SGLang's cu128 wheel index in Dockerfile
|
||||||
|
# (PyPI version lacks SM120/Blackwell CUDA kernels)
|
||||||
|
# LightX2V (Wan2.2-Lightning) and MuseTalk are installed from source in the
|
||||||
|
# Dockerfile because neither ships a stable PyPI release yet. See lines
|
||||||
|
# "LightX2V from source" / "MuseTalk from source" in Dockerfile.
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
# Agent guide — server/
|
||||||
|
|
||||||
|
The audio pipeline and FastAPI surface. The video stack lives under [video_models/](video_models/) and has its own guide.
|
||||||
|
|
||||||
|
## Module map
|
||||||
|
|
||||||
|
- [main.py](main.py) — FastAPI app, lifespan, `/ws/chat` WebSocket, video/avatar HTTP endpoints. Keep business logic out of here; this is a transport layer.
|
||||||
|
- [models.py](models.py) — `ModelManager.load_all()`. All models are loaded once at startup in a fixed order: VAD → ASR → LLM → TTS → (optional) Video. `video_engine` stays `None` when `config.video.enabled` is false — callers MUST tolerate that.
|
||||||
|
- [config.py](config.py) — thin YAML loader, exposes a single `config` dict. Don't scatter `yaml.safe_load` elsewhere.
|
||||||
|
- [pipeline.py](pipeline.py) — `ConversationSession`, one instance per WebSocket. Owns per-session state (VAD stream, conversation history, KV cache, cancel event). Orchestrates VAD → ASR → LLM → TTS and optionally the video branch.
|
||||||
|
- [vad.py](vad.py) — Silero VAD via ONNX Runtime on CPU. `StreamingVAD.process_chunk(pcm_16k) → utterance | None`. Returns a full utterance only on speech→silence transition.
|
||||||
|
- [asr.py](asr.py) — Qwen3-ASR wrapper. Sync, called under `asyncio.to_thread`.
|
||||||
|
- [llm.py](llm.py) — two backends behind a common `generate(history, max_new_tokens, kv_cache_state) → (text, KVCacheState)` signature: `LLMEngine` (local transformers) and `LMStudioEngine` (HTTP, no KV cache).
|
||||||
|
- [tts.py](tts.py) — Kokoro wrapper. The per-segment generator yields `(graphemes, _ps, audio_f32)` tuples at 24 kHz.
|
||||||
|
- [audio_utils.py](audio_utils.py) — `pcm_bytes_to_float32` / `float32_to_pcm_bytes`. The WebSocket protocol is 16 kHz int16 PCM in, 24 kHz int16 PCM out.
|
||||||
|
- [video.py](video.py) — `VideoConfig`, `LoRASpec`, `VideoEngine`. Orchestrates Wan2.2 + MuseTalk. Gated by `config.video.enabled`.
|
||||||
|
- [video_models/](video_models/) — see [video_models/AGENT.md](video_models/AGENT.md) for Blackwell/GGUF gotchas.
|
||||||
|
|
||||||
|
## Session lifecycle (pipeline.py)
|
||||||
|
|
||||||
|
1. Client connects → `ConversationSession(models, send_json, send_bytes)` → `start()` emits `{type: "status", state: "listening"}` and a `{type: "video_mode", ...}` hint.
|
||||||
|
2. Inbound binary frames are 16 kHz int16 PCM → `handle_audio_chunk` → VAD. On speech→silence the session kicks off `_process_utterance` as an `asyncio.Task` so it never blocks the WebSocket receive loop.
|
||||||
|
3. `_process_utterance` flow: ASR → LLM → TTS stream. Each blocking call wraps in `asyncio.to_thread`.
|
||||||
|
4. TTS output is split into short segments via `_split_into_segments` before synthesis so streaming chunks stay small.
|
||||||
|
5. TTS runs on a background `threading.Thread`, feeding a `queue.Queue` that the async loop drains.
|
||||||
|
|
||||||
|
## Barge-in
|
||||||
|
|
||||||
|
- `cancel_event: threading.Event` is the single stop signal. Checked between every stage and inside the TTS queue drain loop.
|
||||||
|
- Two ways to trigger: new VAD utterance while `is_responding` (`handle_audio_chunk`), or an explicit `{type: "interrupt", last_chunk_id}` text message (`interrupt`).
|
||||||
|
- On cancel: set the event, send `{type: "interrupt"}` to the client so it flushes its audio buffer, and discard any pending video clip.
|
||||||
|
- Don't add work after `cancel_event.is_set()` without checking — that's how zombie audio/video reaches the client after a barge-in.
|
||||||
|
|
||||||
|
## Video branch
|
||||||
|
|
||||||
|
The audio pipeline changes shape when `video_engine.is_ready()`:
|
||||||
|
|
||||||
|
- PCM chunks and `response_text` are **not** streamed during the turn — they're buffered.
|
||||||
|
- TTS audio is concatenated into one float32 array.
|
||||||
|
- After TTS completes, `video_engine.generate_speaking_clip(audio, sr, reply_text)` renders the MP4 (blocking, wrapped in `to_thread`).
|
||||||
|
- The full clip + final text is sent as a single `speaking_clip` message.
|
||||||
|
|
||||||
|
If you extend the audio pipeline, preserve this dual-mode behaviour: the client's UX is very different between the two paths, and mixing them (e.g. sending PCM *and* a clip) will double-play the audio.
|
||||||
|
|
||||||
|
## Conventions
|
||||||
|
|
||||||
|
- Dataclasses for structured config (see `VideoConfig`, `LoRASpec`). Parse once in `*.from_dict`; don't re-read `config.yml` mid-session.
|
||||||
|
- `asyncio.to_thread` for any sync model call from an async context. Never call `.generate()` / `.transcribe()` / `.pipeline()` directly on the event loop.
|
||||||
|
- Locks: `VideoEngine._lock` serialises model state mutations; `ConversationSession` is not locked because each WebSocket gets its own instance.
|
||||||
|
- Logging: one `log = logging.getLogger(__name__)` per module. INFO for lifecycle and per-turn milestones; avoid DEBUG spam in hot loops.
|
||||||
|
- Keep `main.py` thin. New endpoints should delegate to a method on `ModelManager` or an engine class.
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
- `tests/unit/test_pipeline_video_branch.py` — the video vs. audio path selection. Update it if you change the `use_video` condition.
|
||||||
|
- `tests/unit/test_video_config.py` / `test_video_engine_logic.py` — config parsing and the pure logic in `video.py`.
|
||||||
|
- Component tests live in `tests/component/` and require the Docker GPU environment. See [tests/README.md](../tests/README.md).
|
||||||
+121
-2
@@ -1,23 +1,27 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.params import Form
|
from fastapi.params import Form
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse, Response
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from server.audio_utils import pcm_bytes_to_float32
|
from server.audio_utils import pcm_bytes_to_float32
|
||||||
from server.models import ModelManager
|
from server.models import ModelManager
|
||||||
from server.pipeline import ConversationSession
|
from server.pipeline import ConversationSession
|
||||||
|
from server.video import LoRASpec
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio")
|
REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio")
|
||||||
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
|
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
|
||||||
|
AVATAR_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "avatars")
|
||||||
|
os.makedirs(AVATAR_DIR, exist_ok=True)
|
||||||
|
|
||||||
model_mgr = ModelManager()
|
model_mgr = ModelManager()
|
||||||
|
|
||||||
@@ -47,6 +51,110 @@ async def set_voice(voice: str = Form(...), lang: str = Form("a")):
|
|||||||
return {"status": "ok", "voice": voice}
|
return {"status": "ok", "voice": voice}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Video / avatar endpoints ---------------------------------------------
|
||||||
|
|
||||||
|
def _require_video() -> "object":
|
||||||
|
"""Return the video engine, or raise 404 if video mode isn't enabled."""
|
||||||
|
ve = model_mgr.video_engine
|
||||||
|
if ve is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Video engine disabled. Set config.video.enabled=true and restart.",
|
||||||
|
)
|
||||||
|
return ve
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/set-avatar")
|
||||||
|
async def set_avatar(image: UploadFile):
|
||||||
|
"""Upload an avatar image and (re)generate cached clips."""
|
||||||
|
ve = _require_video()
|
||||||
|
suffix = os.path.splitext(image.filename or "avatar.png")[1] or ".png"
|
||||||
|
dest = os.path.join(AVATAR_DIR, f"avatar{suffix}")
|
||||||
|
with open(dest, "wb") as f:
|
||||||
|
f.write(await image.read())
|
||||||
|
log.info("Avatar saved to %s", dest)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(ve.set_avatar, dest)
|
||||||
|
except Exception as e:
|
||||||
|
log.exception("set_avatar failed")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Avatar setup failed: {e}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"avatar_path": dest,
|
||||||
|
"idle_clip_url": "/api/idle-clip",
|
||||||
|
"mode": ve.cfg.mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/idle-clip")
|
||||||
|
async def idle_clip():
|
||||||
|
"""Return the cached idle loop MP4."""
|
||||||
|
ve = _require_video()
|
||||||
|
data = ve.get_idle_clip()
|
||||||
|
if data is None:
|
||||||
|
raise HTTPException(status_code=404, detail="No idle clip. Upload an avatar first.")
|
||||||
|
return Response(content=data, media_type="video/mp4")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/set-video-mode")
|
||||||
|
async def set_video_mode(mode: str = Form(...)):
|
||||||
|
"""Switch between 'off', 'library', and 'reflective'.
|
||||||
|
|
||||||
|
'off' leaves the video engine loaded but makes the pipeline take the
|
||||||
|
PCM streaming path on subsequent turns (by marking the engine not-ready
|
||||||
|
from the client's perspective via a simple flag).
|
||||||
|
"""
|
||||||
|
ve = _require_video()
|
||||||
|
if mode not in ("off", "library", "reflective"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="mode must be one of: off, library, reflective",
|
||||||
|
)
|
||||||
|
# Switching between library/reflective changes how set_avatar prebakes
|
||||||
|
# clips. Require a fresh avatar upload afterwards to re-bake.
|
||||||
|
ve.cfg.mode = mode
|
||||||
|
return {"status": "ok", "mode": mode, "note": "Re-upload avatar to re-bake library clips." if mode == "library" else ""}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/reload-loras")
|
||||||
|
async def reload_loras(body: dict):
|
||||||
|
"""Hot-reload LoRA stack. Body: ``{"loras": [{"path","weight","target","name"}]}``.
|
||||||
|
|
||||||
|
Regenerates the idle clip if an avatar is already set, since the new
|
||||||
|
LoRAs change the base style.
|
||||||
|
"""
|
||||||
|
ve = _require_video()
|
||||||
|
raw = body.get("loras") or []
|
||||||
|
specs: list[LoRASpec] = []
|
||||||
|
for entry in raw:
|
||||||
|
if not entry or "path" not in entry:
|
||||||
|
continue
|
||||||
|
target = str(entry.get("target", "both")).lower()
|
||||||
|
if target != "both":
|
||||||
|
target = "both"
|
||||||
|
specs.append(
|
||||||
|
LoRASpec(
|
||||||
|
path=str(entry["path"]),
|
||||||
|
weight=float(entry.get("weight", 1.0)),
|
||||||
|
target=target, # type: ignore[arg-type]
|
||||||
|
name=entry.get("name"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(ve.load_loras, specs)
|
||||||
|
if ve.avatar_path:
|
||||||
|
log.info("Regenerating idle clip after LoRA reload.")
|
||||||
|
await asyncio.to_thread(ve.set_avatar, ve.avatar_path)
|
||||||
|
except Exception as e:
|
||||||
|
log.exception("reload_loras failed")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
return {"status": "ok", "lora_count": len(specs), "idle_clip_url": "/api/idle-clip"}
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/ws/chat")
|
@app.websocket("/ws/chat")
|
||||||
async def websocket_chat(ws: WebSocket):
|
async def websocket_chat(ws: WebSocket):
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
@@ -61,6 +169,17 @@ async def websocket_chat(ws: WebSocket):
|
|||||||
session = ConversationSession(model_mgr, send_json, send_bytes)
|
session = ConversationSession(model_mgr, send_json, send_bytes)
|
||||||
await session.start()
|
await session.start()
|
||||||
|
|
||||||
|
# Tell the client whether video mode is active so it knows whether to
|
||||||
|
# suppress PCM playback and wait for speaking_clip messages instead.
|
||||||
|
ve = model_mgr.video_engine
|
||||||
|
await send_json({
|
||||||
|
"type": "video_mode",
|
||||||
|
"enabled": ve is not None,
|
||||||
|
"ready": ve.is_ready() if ve is not None else False,
|
||||||
|
"mode": ve.cfg.mode if ve is not None else "off",
|
||||||
|
"idle_clip_url": "/api/idle-clip" if (ve is not None and ve.get_idle_clip()) else None,
|
||||||
|
})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
message = await ws.receive()
|
message = await ws.receive()
|
||||||
|
|||||||
+25
-2
@@ -5,6 +5,7 @@ from server.vad import StreamingVAD
|
|||||||
from server.asr import ASREngine
|
from server.asr import ASREngine
|
||||||
from server.llm import LLMEngine
|
from server.llm import LLMEngine
|
||||||
from server.tts import TTSEngine
|
from server.tts import TTSEngine
|
||||||
|
from server.video import VideoConfig, VideoEngine
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -31,6 +32,7 @@ class ModelManager:
|
|||||||
self.asr_engine: ASREngine | None = None
|
self.asr_engine: ASREngine | None = None
|
||||||
self.llm_engine: LLMEngine | None = None
|
self.llm_engine: LLMEngine | None = None
|
||||||
self.tts_engine: TTSEngine | None = None
|
self.tts_engine: TTSEngine | None = None
|
||||||
|
self.video_engine: VideoEngine | None = None
|
||||||
|
|
||||||
def load_all(self):
|
def load_all(self):
|
||||||
"""Load all models sequentially. Call from the main process."""
|
"""Load all models sequentially. Call from the main process."""
|
||||||
@@ -38,6 +40,7 @@ class ModelManager:
|
|||||||
self._load_asr()
|
self._load_asr()
|
||||||
self._load_llm()
|
self._load_llm()
|
||||||
self._load_tts()
|
self._load_tts()
|
||||||
|
self._load_video()
|
||||||
log.info("All models loaded successfully.")
|
log.info("All models loaded successfully.")
|
||||||
|
|
||||||
def _load_vad(self):
|
def _load_vad(self):
|
||||||
@@ -84,8 +87,8 @@ class ModelManager:
|
|||||||
log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
|
log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
model_name = "Qwen/Qwen3.5-0.8B"
|
# model_name = "Qwen/Qwen3.5-0.8B"
|
||||||
|
model_name = "dphn/Dolphin-X1-8B-FP8"
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
device = get_device()
|
device = get_device()
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
@@ -101,6 +104,26 @@ class ModelManager:
|
|||||||
self.tts_engine = TTSEngine()
|
self.tts_engine = TTSEngine()
|
||||||
log.info("Kokoro TTS loaded.")
|
log.info("Kokoro TTS loaded.")
|
||||||
|
|
||||||
|
def _load_video(self):
|
||||||
|
"""Load the avatar video stack iff config.video.enabled is true.
|
||||||
|
|
||||||
|
Leaves ``video_engine`` as None when disabled so existing voice flow
|
||||||
|
is untouched. Later phases replace this stub with actual Wan2.2 +
|
||||||
|
MuseTalk loading inside ``VideoEngine``.
|
||||||
|
"""
|
||||||
|
from server.config import config
|
||||||
|
|
||||||
|
video_cfg_raw = config.get("video", {}) or {}
|
||||||
|
if not video_cfg_raw.get("enabled", False):
|
||||||
|
log.info("Video engine disabled (config.video.enabled=false). Skipping load.")
|
||||||
|
return
|
||||||
|
|
||||||
|
log.info("Loading avatar video engine...")
|
||||||
|
cfg = VideoConfig.from_dict(video_cfg_raw)
|
||||||
|
self.video_engine = VideoEngine(cfg)
|
||||||
|
self.video_engine.load_models()
|
||||||
|
log.info("Avatar video engine loaded (mode=%s).", cfg.mode)
|
||||||
|
|
||||||
def create_vad(self) -> StreamingVAD:
|
def create_vad(self) -> StreamingVAD:
|
||||||
"""Create a new StreamingVAD instance for a client session."""
|
"""Create a new StreamingVAD instance for a client session."""
|
||||||
return StreamingVAD(self.vad_model)
|
return StreamingVAD(self.vad_model)
|
||||||
|
|||||||
+60
-1
@@ -157,11 +157,20 @@ class ConversationSession:
|
|||||||
|
|
||||||
# TTS - stream chunks with per-sentence text
|
# TTS - stream chunks with per-sentence text
|
||||||
await self.send_json({"type": "status", "state": "speaking"})
|
await self.send_json({"type": "status", "state": "speaking"})
|
||||||
|
|
||||||
|
# Video-mode branch: if a video engine is loaded AND an avatar is
|
||||||
|
# set, buffer the full TTS output into a single blob, run MuseTalk
|
||||||
|
# lip-sync (library or reflective source), mux to MP4, and send the
|
||||||
|
# full clip + text in one shot. The client plays the MP4 (which
|
||||||
|
# carries audio) instead of the per-chunk PCM path.
|
||||||
|
video_engine = getattr(self.models, "video_engine", None)
|
||||||
|
use_video = video_engine is not None and video_engine.is_ready()
|
||||||
|
|
||||||
chunk_queue = queue.Queue()
|
chunk_queue = queue.Queue()
|
||||||
self._last_played_chunk_id = None
|
self._last_played_chunk_id = None
|
||||||
|
|
||||||
segments = _split_into_segments(response)
|
segments = _split_into_segments(response)
|
||||||
log.info(f"TTS: split response into {len(segments)} segments")
|
log.info(f"TTS: split response into {len(segments)} segments (video={use_video})")
|
||||||
|
|
||||||
def _tts_worker():
|
def _tts_worker():
|
||||||
try:
|
try:
|
||||||
@@ -187,6 +196,10 @@ class ConversationSession:
|
|||||||
chunk_id = 0
|
chunk_id = 0
|
||||||
# Maps chunk_id -> cumulative text up to and including that chunk
|
# Maps chunk_id -> cumulative text up to and including that chunk
|
||||||
chunk_text_map: dict[int, str] = {}
|
chunk_text_map: dict[int, str] = {}
|
||||||
|
# Video mode accumulator: we buffer all TTS audio into one float32
|
||||||
|
# array so MuseTalk can align against the full utterance.
|
||||||
|
audio_buffer: list[np.ndarray] = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
|
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
|
||||||
@@ -202,6 +215,12 @@ class ConversationSession:
|
|||||||
spoken_text += sentence_text
|
spoken_text += sentence_text
|
||||||
chunk_text_map[chunk_id] = spoken_text
|
chunk_text_map[chunk_id] = spoken_text
|
||||||
|
|
||||||
|
if use_video:
|
||||||
|
audio_buffer.append(audio)
|
||||||
|
# Don't stream text or PCM during video mode — we'll send
|
||||||
|
# everything after the clip renders so the client doesn't
|
||||||
|
# start displaying text before the video is ready.
|
||||||
|
else:
|
||||||
await self.send_json({
|
await self.send_json({
|
||||||
"type": "response_text",
|
"type": "response_text",
|
||||||
"text": sentence_text,
|
"text": sentence_text,
|
||||||
@@ -219,6 +238,46 @@ class ConversationSession:
|
|||||||
|
|
||||||
tts_thread.join(timeout=2.0)
|
tts_thread.join(timeout=2.0)
|
||||||
|
|
||||||
|
# Video mode: render the speaking clip now that TTS is done.
|
||||||
|
if use_video and audio_buffer and not self.cancel_event.is_set():
|
||||||
|
try:
|
||||||
|
full_audio = np.concatenate(audio_buffer).astype(np.float32)
|
||||||
|
sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000)
|
||||||
|
log.info(
|
||||||
|
"Video: rendering speaking clip (audio=%ds, mode=%s)",
|
||||||
|
int(len(full_audio) / sample_rate), video_engine.cfg.mode,
|
||||||
|
)
|
||||||
|
mp4_bytes = await asyncio.to_thread(
|
||||||
|
video_engine.generate_speaking_clip,
|
||||||
|
full_audio,
|
||||||
|
sample_rate,
|
||||||
|
response,
|
||||||
|
)
|
||||||
|
if self.cancel_event.is_set():
|
||||||
|
log.info("Video clip discarded (cancelled during render).")
|
||||||
|
else:
|
||||||
|
duration_ms = int(len(full_audio) / sample_rate * 1000)
|
||||||
|
await self.send_json({
|
||||||
|
"type": "speaking_clip",
|
||||||
|
"chunk_id": 0,
|
||||||
|
"duration_ms": duration_ms,
|
||||||
|
"text": response,
|
||||||
|
"size_bytes": len(mp4_bytes),
|
||||||
|
})
|
||||||
|
await self.send_bytes(mp4_bytes)
|
||||||
|
except Exception:
|
||||||
|
log.exception("Video speaking-clip render failed; falling back silently.")
|
||||||
|
# Best-effort: tell the client nothing was spoken visually.
|
||||||
|
try:
|
||||||
|
await self.send_json({
|
||||||
|
"type": "response_text",
|
||||||
|
"text": response,
|
||||||
|
"chunk_id": 0,
|
||||||
|
"final": True,
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Determine what was actually heard by the client
|
# Determine what was actually heard by the client
|
||||||
was_interrupted = spoken_text.strip() != response.strip()
|
was_interrupted = spoken_text.strip() != response.strip()
|
||||||
if was_interrupted and self._last_played_chunk_id is not None:
|
if was_interrupted and self._last_played_chunk_id is not None:
|
||||||
|
|||||||
+419
@@ -0,0 +1,419 @@
|
|||||||
|
"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync.
|
||||||
|
|
||||||
|
Top-level orchestrator. The heavy 3rd-party model code is isolated in
|
||||||
|
``server/video_models/`` so each wrapper can be updated independently.
|
||||||
|
|
||||||
|
This module is only imported by ``server/models.py`` when
|
||||||
|
``config.video.enabled`` is true. When disabled, the existing voice pipeline
|
||||||
|
is completely untouched.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
LoRATarget = Literal["both"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoRASpec:
|
||||||
|
"""One LoRA adapter entry from ``config.video.loras``.
|
||||||
|
|
||||||
|
The dense Wan2.2-TI2V-5B DIT has a single set of weights (no MoE
|
||||||
|
experts), so ``target`` is always ``"both"``. The field is kept for
|
||||||
|
forward compatibility and config-file compatibility with older MoE
|
||||||
|
configs — legacy ``"high_noise"`` / ``"low_noise"`` values are coerced
|
||||||
|
to ``"both"`` in ``VideoConfig.from_dict``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path: str
|
||||||
|
weight: float = 1.0
|
||||||
|
target: LoRATarget = "both"
|
||||||
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoConfig:
|
||||||
|
"""Flattened view of the ``video:`` section of config.yml."""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
backend: str = "lightx2v"
|
||||||
|
mode: str = "reflective" # "library" | "reflective"
|
||||||
|
resolution: int = 480
|
||||||
|
fps: int = 16
|
||||||
|
library_base_clip_count: int = 4
|
||||||
|
library_base_clip_seconds: int = 6
|
||||||
|
reflective_clip_seconds: int = 5
|
||||||
|
reflective_prompt_template: str = (
|
||||||
|
"webcam view of a person speaking, {reply_hint}, casual gestures, "
|
||||||
|
"natural lighting, soft focus background"
|
||||||
|
)
|
||||||
|
reflective_prompt_reply_words: int = 18
|
||||||
|
loras: list[LoRASpec] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Model paths — can be overridden via config.yml.video.models.
|
||||||
|
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
|
||||||
|
# The bf16 DIT shards in this repo are skipped — we
|
||||||
|
# replace them with a quantised GGUF from wan22_dit_repo.
|
||||||
|
# wan22_dit_repo : HF repo id (or local dir) providing the single
|
||||||
|
# dense GGUF DIT checkpoint (5B Turbo).
|
||||||
|
# wan22_dit_quant_scheme : GGUF quant level, e.g. "gguf-Q8_0" (default)
|
||||||
|
# or "gguf-Q4_K_M" for lower VRAM.
|
||||||
|
# wan22_config_json : path to the LightX2V inference config template the
|
||||||
|
# Wan22Pipeline will fill in with absolute ckpt paths.
|
||||||
|
wan22_base_repo: str = "Wan-AI/Wan2.2-TI2V-5B"
|
||||||
|
wan22_dit_repo: str = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||||
|
wan22_dit_quant_scheme: str = "gguf-Q8_0"
|
||||||
|
wan22_t5_quantized: bool = True
|
||||||
|
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
|
||||||
|
wan22_model_cls: str = "wan2.2"
|
||||||
|
musetalk_enabled: bool = True
|
||||||
|
musetalk_model_path: str = "TMElyralab/MuseTalk"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, raw: dict) -> "VideoConfig":
|
||||||
|
raw = raw or {}
|
||||||
|
library = raw.get("library", {}) or {}
|
||||||
|
reflective = raw.get("reflective", {}) or {}
|
||||||
|
models_raw = raw.get("models", {}) or {}
|
||||||
|
loras_raw = raw.get("loras") or []
|
||||||
|
|
||||||
|
default_template = (
|
||||||
|
"webcam view of a person speaking, {reply_hint}, casual gestures, "
|
||||||
|
"natural lighting, soft focus background"
|
||||||
|
)
|
||||||
|
|
||||||
|
loras: list[LoRASpec] = []
|
||||||
|
for entry in loras_raw:
|
||||||
|
if not entry or "path" not in entry:
|
||||||
|
continue
|
||||||
|
target = str(entry.get("target", "both")).lower()
|
||||||
|
if target != "both":
|
||||||
|
log.warning(
|
||||||
|
"LoRA %s: target %r is MoE-era; coercing to 'both' "
|
||||||
|
"(dense 5B has a single DIT).",
|
||||||
|
entry.get("path"), target,
|
||||||
|
)
|
||||||
|
target = "both"
|
||||||
|
loras.append(
|
||||||
|
LoRASpec(
|
||||||
|
path=str(entry["path"]),
|
||||||
|
weight=float(entry.get("weight", 1.0)),
|
||||||
|
target=target, # type: ignore[arg-type]
|
||||||
|
name=entry.get("name"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
enabled=bool(raw.get("enabled", False)),
|
||||||
|
backend=str(raw.get("backend", "lightx2v")),
|
||||||
|
mode=str(raw.get("mode", "reflective")),
|
||||||
|
resolution=int(raw.get("resolution", 480)),
|
||||||
|
fps=int(raw.get("fps", 16)),
|
||||||
|
library_base_clip_count=int(library.get("base_clip_count", 4)),
|
||||||
|
library_base_clip_seconds=int(library.get("base_clip_seconds", 6)),
|
||||||
|
reflective_clip_seconds=int(reflective.get("clip_seconds", 5)),
|
||||||
|
reflective_prompt_template=str(
|
||||||
|
reflective.get("clip_prompt_template", default_template)
|
||||||
|
),
|
||||||
|
reflective_prompt_reply_words=int(reflective.get("prompt_reply_words", 18)),
|
||||||
|
loras=loras,
|
||||||
|
wan22_base_repo=str(
|
||||||
|
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-TI2V-5B")
|
||||||
|
),
|
||||||
|
wan22_dit_repo=str(
|
||||||
|
models_raw.get(
|
||||||
|
"wan22_dit_repo",
|
||||||
|
"hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
wan22_dit_quant_scheme=str(
|
||||||
|
models_raw.get("wan22_dit_quant_scheme", "gguf-Q8_0")
|
||||||
|
),
|
||||||
|
wan22_t5_quantized=bool(
|
||||||
|
models_raw.get("wan22_t5_quantized", True)
|
||||||
|
),
|
||||||
|
wan22_config_json=str(
|
||||||
|
models_raw.get(
|
||||||
|
"wan22_config_json",
|
||||||
|
"/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
wan22_model_cls=str(
|
||||||
|
models_raw.get("wan22_model_cls", "wan2.2")
|
||||||
|
),
|
||||||
|
musetalk_enabled=bool(raw.get("musetalk", {}).get("enabled", True))
|
||||||
|
if isinstance(raw.get("musetalk"), dict)
|
||||||
|
else bool(raw.get("musetalk_enabled", True)),
|
||||||
|
musetalk_model_path=str(
|
||||||
|
models_raw.get("musetalk_path", "TMElyralab/MuseTalk")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Library-mode base-clip prompts. Varied gestures so the pre-baked set feels
|
||||||
|
# less repetitive when replayed. Kept module-level so tests can import them.
|
||||||
|
LIBRARY_BASE_PROMPTS = [
|
||||||
|
"webcam view of a person speaking, subtle head nods, casual expression, "
|
||||||
|
"natural lighting, soft focus background",
|
||||||
|
"webcam view of a person speaking, slight smile, gentle hand gesture, "
|
||||||
|
"natural lighting, soft focus background",
|
||||||
|
"webcam view of a person speaking, looking thoughtful, small head tilt, "
|
||||||
|
"natural lighting, soft focus background",
|
||||||
|
"webcam view of a person speaking, engaged and attentive, minor shoulder "
|
||||||
|
"movement, natural lighting, soft focus background",
|
||||||
|
"webcam view of a person speaking, relaxed posture, blinking naturally, "
|
||||||
|
"natural lighting, soft focus background",
|
||||||
|
]
|
||||||
|
|
||||||
|
IDLE_PROMPT = (
|
||||||
|
"webcam view of a person listening quietly, mouth closed, subtle "
|
||||||
|
"breathing, occasional blinks, calm expression, natural lighting, "
|
||||||
|
"soft focus background"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoEngine:
|
||||||
|
"""Top-level video generation orchestrator.
|
||||||
|
|
||||||
|
Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's
|
||||||
|
pre-rendered clips. Exposed to ``ConversationSession`` via
|
||||||
|
``ModelManager.video_engine``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg: VideoConfig):
|
||||||
|
self.cfg = cfg
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
# Avatar state
|
||||||
|
self.avatar_path: str | None = None
|
||||||
|
self.idle_clip_mp4: bytes | None = None
|
||||||
|
# Pre-baked speaking base clips for library mode. Each entry is a
|
||||||
|
# contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8.
|
||||||
|
self.speaking_base_frames: list[np.ndarray] = []
|
||||||
|
# Round-robin pointer for picking a library clip per turn
|
||||||
|
self._library_cursor = 0
|
||||||
|
|
||||||
|
# Model wrappers — instantiated lazily by ``load_models()`` so unit
|
||||||
|
# tests can exercise VideoEngine without touching CUDA at all.
|
||||||
|
self._wan22 = None # server.video_models.wan22.Wan22Pipeline
|
||||||
|
self._musetalk = None # server.video_models.musetalk.MuseTalkEngine
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
"VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).",
|
||||||
|
cfg.mode, cfg.resolution, cfg.fps, len(cfg.loras),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Model loading --------------------------------------------------
|
||||||
|
|
||||||
|
def load_models(self) -> None:
|
||||||
|
"""Instantiate the underlying model wrappers.
|
||||||
|
|
||||||
|
Separated from ``__init__`` so tests can mock ``_wan22``/``_musetalk``
|
||||||
|
without triggering Wan2.2's ~12-16GB VRAM allocation.
|
||||||
|
"""
|
||||||
|
from server.video_models.wan22 import Wan22Pipeline
|
||||||
|
from server.video_models.musetalk import MuseTalkEngine
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
"Loading Wan2.2 pipeline (base=%s, dit=%s, quant=%s)...",
|
||||||
|
self.cfg.wan22_base_repo, self.cfg.wan22_dit_repo,
|
||||||
|
self.cfg.wan22_dit_quant_scheme,
|
||||||
|
)
|
||||||
|
self._wan22 = Wan22Pipeline(
|
||||||
|
base_repo=self.cfg.wan22_base_repo,
|
||||||
|
dit_repo=self.cfg.wan22_dit_repo,
|
||||||
|
config_json=self.cfg.wan22_config_json,
|
||||||
|
model_cls=self.cfg.wan22_model_cls,
|
||||||
|
resolution=self.cfg.resolution,
|
||||||
|
fps=self.cfg.fps,
|
||||||
|
dit_quant_scheme=self.cfg.wan22_dit_quant_scheme,
|
||||||
|
t5_quantized=self.cfg.wan22_t5_quantized,
|
||||||
|
)
|
||||||
|
if self.cfg.loras:
|
||||||
|
self._wan22.load_loras(self.cfg.loras)
|
||||||
|
log.info("Wan2.2 pipeline ready.")
|
||||||
|
|
||||||
|
if self.cfg.musetalk_enabled:
|
||||||
|
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
|
||||||
|
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
|
||||||
|
log.info("MuseTalk engine ready.")
|
||||||
|
else:
|
||||||
|
log.info("MuseTalk disabled via config — skipping lip-sync pass.")
|
||||||
|
self._musetalk = None
|
||||||
|
|
||||||
|
# --- Readiness ------------------------------------------------------
|
||||||
|
|
||||||
|
def is_ready(self) -> bool:
|
||||||
|
"""True when an avatar is set and a speaking clip can be produced."""
|
||||||
|
musetalk_ok = (not self.cfg.musetalk_enabled) or self._musetalk is not None
|
||||||
|
return (
|
||||||
|
self._wan22 is not None
|
||||||
|
and musetalk_ok
|
||||||
|
and self.avatar_path is not None
|
||||||
|
and self.idle_clip_mp4 is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- LoRA management ------------------------------------------------
|
||||||
|
|
||||||
|
def load_loras(self, specs: list[LoRASpec]) -> None:
|
||||||
|
"""Apply a list of LoRA adapters to the Wan2.2 base.
|
||||||
|
|
||||||
|
Replaces any previously applied LoRAs. Safe to call after init for
|
||||||
|
hot-reload via ``POST /api/reload-loras``.
|
||||||
|
"""
|
||||||
|
if self._wan22 is None:
|
||||||
|
raise RuntimeError("load_loras called before load_models()")
|
||||||
|
with self._lock:
|
||||||
|
self._wan22.unload_loras()
|
||||||
|
self._wan22.load_loras(specs)
|
||||||
|
self.cfg.loras = list(specs)
|
||||||
|
log.info("Applied %d LoRA(s): %s",
|
||||||
|
len(specs),
|
||||||
|
", ".join(s.name or s.path for s in specs) or "<none>")
|
||||||
|
|
||||||
|
# --- Avatar lifecycle ----------------------------------------------
|
||||||
|
|
||||||
|
def set_avatar(self, image_path: str) -> None:
|
||||||
|
"""Register an avatar image and pre-generate cached clips.
|
||||||
|
|
||||||
|
- Always: generate the idle loop.
|
||||||
|
- Library mode: also pre-generate ``library.base_clip_count``
|
||||||
|
speaking base clips.
|
||||||
|
- Reflective mode: idle loop only.
|
||||||
|
"""
|
||||||
|
if self._wan22 is None:
|
||||||
|
raise RuntimeError("set_avatar called before load_models()")
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
log.info("Setting avatar: %s", image_path)
|
||||||
|
self.avatar_path = image_path
|
||||||
|
# Drop any previously cached clips so the new avatar's library
|
||||||
|
# doesn't mix with the old.
|
||||||
|
self.speaking_base_frames = []
|
||||||
|
self.idle_clip_mp4 = None
|
||||||
|
|
||||||
|
# Idle clip: short loop, neutral/listening prompt.
|
||||||
|
log.info("Generating idle clip...")
|
||||||
|
idle_frames = self._wan22.generate_i2v(
|
||||||
|
image_path=image_path,
|
||||||
|
prompt=IDLE_PROMPT,
|
||||||
|
seconds=self.cfg.library_base_clip_seconds,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
from server.video_models.muxer import frames_to_mp4_loop
|
||||||
|
self.idle_clip_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps)
|
||||||
|
log.info("Idle clip ready (%d bytes).", len(self.idle_clip_mp4))
|
||||||
|
|
||||||
|
# Library mode: pre-bake N speaking base clips.
|
||||||
|
if self.cfg.mode == "library":
|
||||||
|
n = self.cfg.library_base_clip_count
|
||||||
|
log.info("Pre-baking %d speaking base clip(s) for library mode.", n)
|
||||||
|
for i in range(n):
|
||||||
|
prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)]
|
||||||
|
frames = self._wan22.generate_i2v(
|
||||||
|
image_path=image_path,
|
||||||
|
prompt=prompt,
|
||||||
|
seconds=self.cfg.library_base_clip_seconds,
|
||||||
|
seed=i + 1,
|
||||||
|
)
|
||||||
|
self.speaking_base_frames.append(frames)
|
||||||
|
log.info(" base clip %d/%d rendered", i + 1, n)
|
||||||
|
|
||||||
|
self._library_cursor = 0
|
||||||
|
|
||||||
|
def get_idle_clip(self) -> bytes | None:
|
||||||
|
return self.idle_clip_mp4
|
||||||
|
|
||||||
|
# --- Per-turn generation -------------------------------------------
|
||||||
|
|
||||||
|
def generate_speaking_clip(
|
||||||
|
self,
|
||||||
|
audio_f32: np.ndarray,
|
||||||
|
sample_rate: int,
|
||||||
|
reply_text: str,
|
||||||
|
) -> bytes:
|
||||||
|
"""Produce a lip-synced MP4 for one assistant turn."""
|
||||||
|
if not self.is_ready():
|
||||||
|
raise RuntimeError(
|
||||||
|
"generate_speaking_clip: engine not ready "
|
||||||
|
"(avatar set? models loaded?)"
|
||||||
|
)
|
||||||
|
assert self._wan22 is not None
|
||||||
|
|
||||||
|
# 1. Source base frames.
|
||||||
|
if self.cfg.mode == "library":
|
||||||
|
base_frames = self._pick_library_frames(audio_f32, sample_rate)
|
||||||
|
else: # reflective
|
||||||
|
prompt = self._derive_prompt(reply_text)
|
||||||
|
log.info("Reflective prompt: %s", prompt[:120])
|
||||||
|
base_frames = self._wan22.generate_i2v(
|
||||||
|
image_path=self.avatar_path or "",
|
||||||
|
prompt=prompt,
|
||||||
|
seconds=self.cfg.reflective_clip_seconds,
|
||||||
|
seed=None, # random each turn
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Lip-sync the base frames to the given audio (if enabled).
|
||||||
|
if self._musetalk is not None:
|
||||||
|
synced_frames = self._musetalk.lip_sync(
|
||||||
|
frames=base_frames,
|
||||||
|
audio=audio_f32,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
fps=self.cfg.fps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
synced_frames = base_frames
|
||||||
|
|
||||||
|
# 3. Mux frames + audio into an MP4.
|
||||||
|
from server.video_models.muxer import frames_and_audio_to_mp4
|
||||||
|
return frames_and_audio_to_mp4(
|
||||||
|
frames=synced_frames,
|
||||||
|
audio=audio_f32,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
fps=self.cfg.fps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _pick_library_frames(
|
||||||
|
self, audio_f32: np.ndarray, sample_rate: int
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Round-robin pick from the pre-baked library, clipped or looped
|
||||||
|
to roughly the audio's duration so there's no long freeze frame."""
|
||||||
|
if not self.speaking_base_frames:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Library mode has no pre-baked base clips. "
|
||||||
|
"Was set_avatar called with mode=library?"
|
||||||
|
)
|
||||||
|
frames = self.speaking_base_frames[
|
||||||
|
self._library_cursor % len(self.speaking_base_frames)
|
||||||
|
]
|
||||||
|
self._library_cursor += 1
|
||||||
|
|
||||||
|
target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
|
||||||
|
if target_frames <= 0:
|
||||||
|
return frames
|
||||||
|
if target_frames <= len(frames):
|
||||||
|
return frames[:target_frames]
|
||||||
|
# Loop (with a mirror tail to soften the seam) to cover longer audio.
|
||||||
|
loops = target_frames // len(frames) + 1
|
||||||
|
extended = np.concatenate([frames] * loops, axis=0)
|
||||||
|
return extended[:target_frames]
|
||||||
|
|
||||||
|
def _derive_prompt(self, reply_text: str) -> str:
|
||||||
|
"""Template-based prompt builder for reflective mode.
|
||||||
|
|
||||||
|
Takes up to ``prompt_reply_words`` words from the start of the reply
|
||||||
|
and interpolates them into the configured template. Cheap,
|
||||||
|
deterministic, no extra LLM call.
|
||||||
|
"""
|
||||||
|
words = (reply_text or "").split()
|
||||||
|
hint = " ".join(words[: self.cfg.reflective_prompt_reply_words]).strip()
|
||||||
|
if not hint:
|
||||||
|
hint = "calm and friendly"
|
||||||
|
return self.cfg.reflective_prompt_template.format(reply_hint=hint)
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
# Agent guide — server/video_models/
|
||||||
|
|
||||||
|
Wrappers around 3rd-party video models. These are the trickiest files in the repo: LightX2V's internals move quickly upstream, and the Blackwell (RTX 5090 / SM120) GPU path requires several non-obvious patches layered on top. Read this before editing.
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
|
||||||
|
- [wan22.py](wan22.py) — LightX2V Wan2.2-I2V A14B MoE pipeline. Supports fp8 safetensors and GGUF DIT checkpoints. Loaded once at startup, held resident; per-turn calls go through `generate_i2v` and `switch_lora`.
|
||||||
|
- [musetalk.py](musetalk.py) — MuseTalk lip-sync over base frames + TTS audio.
|
||||||
|
- [muxer.py](muxer.py) — thin ffmpeg wrappers: frames → MP4 loop, frames + audio → MP4.
|
||||||
|
|
||||||
|
Nothing here is imported unless `config.video.enabled` is true.
|
||||||
|
|
||||||
|
## LightX2V entry points (upstream API)
|
||||||
|
|
||||||
|
Use these symbols, not internal/private ones:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lightx2v.utils.set_config import set_config
|
||||||
|
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
|
||||||
|
from lightx2v.infer import init_runner
|
||||||
|
|
||||||
|
config = set_config(args) # args is an argparse.Namespace
|
||||||
|
input_info = init_empty_input_info(args.task, args.support_tasks)
|
||||||
|
runner = init_runner(config) # loads all weights — ONCE
|
||||||
|
|
||||||
|
update_input_info_from_dict(input_info, {...}) # per-turn inputs
|
||||||
|
runner.run_pipeline(input_info) # MP4 written to save_result_path
|
||||||
|
runner.switch_lora(lora_path, strength) # hot-swap
|
||||||
|
```
|
||||||
|
|
||||||
|
Keep model load out of the per-turn path — `init_runner` is expensive.
|
||||||
|
|
||||||
|
## Blackwell (SM120) patches — do not remove without testing
|
||||||
|
|
||||||
|
The GGUF pipeline works on a 5090 only because of layered patches in `wan22.py` and tuning in the LightX2V JSON configs under [configs/lightx2v/](../../configs/lightx2v/). Each patch exists because a stock upstream path segfaults or silently miscomputes on SM120.
|
||||||
|
|
||||||
|
**Dtype plumbing (GGUF path):**
|
||||||
|
|
||||||
|
- Default `DTYPE` must be `BF16` at `init_runner()` time — T5 offload buffers break if FP16 at init.
|
||||||
|
- Flip `BF16 → FP16` *after* `init_runner()`.
|
||||||
|
- Wrap T5 encoder so it runs under BF16 internally, then cast outputs `bf16 → fp16` before handing to the DIT. See `_patch_t5_dtype_for_gguf`.
|
||||||
|
- Cast VAE **both** layers: the inner `.model` via `.to(fp16)` **and** the outer `WanVAE` wrapper's `mean` / `inv_std` / `scale` tensors. Missing the wrapper tensors upcasts the latent during decode's `z/inv_std + mean`.
|
||||||
|
- DIT `pre_weight.patch_embedding.pin_weight` loads as fp32 (only `pin_bias` is fp16). Cast **and** re-pin via `.pin_memory()` — skipping re-pin segfaults during `to_cuda` H2D copy.
|
||||||
|
- `sgl_kernel`'s fp8 scaled matmul is patched to `torch._scaled_mm` in `_patch_fp8_scaled_mm_for_blackwell`.
|
||||||
|
|
||||||
|
**LightX2V JSON config (see `wan22_i2v_gguf_distill.json`):**
|
||||||
|
|
||||||
|
- `modulate_type: "torch"` — Triton `fuse_scale_shift_kernel` segfaults in `ast_to_ttir` on Triton 3.4 + SM120.
|
||||||
|
- `rope_type: "torch"` — flashinfer isn't installed.
|
||||||
|
- `self_attn_1_type` / `cross_attn_*_type`: `"torch_sdpa"` — flash_attn3 unavailable; `sageattention==1.0.6` from PyPI segfaults on Blackwell (newer requires source build).
|
||||||
|
|
||||||
|
If you add a new quant scheme or a new model_cls, create its own JSON under `configs/lightx2v/` mirroring these choices, and exercise it end-to-end via a new `tests/component/test_NN_*.py` before wiring it into the default config.
|
||||||
|
|
||||||
|
## HF download layout
|
||||||
|
|
||||||
|
`Wan-AI/Wan2.2-I2V-A14B` ships ~28 GB of bf16 DIT shards we replace with the quantised `dit_repo`. `BASE_REPO_IGNORE_PATTERNS` in [wan22.py](wan22.py) excludes them but **keeps** `high_noise_model/*.json` and `low_noise_model/*.json` — `set_config` parses architecture params (`dim`, etc.) from those. Don't broaden the ignore pattern without checking.
|
||||||
|
|
||||||
|
Supported quant schemes live in `wan22_dit_quant_scheme`:
|
||||||
|
|
||||||
|
- `fp8-sgl` — `lightx2v/Wan2.2-Distill-Models`, two `.safetensors` files
|
||||||
|
- `gguf-Q4_K_M`, `gguf-Q8_0`, … — `QuantStack/Wan2.2-I2V-A14B-GGUF`, layout `HighNoise/…` and `LowNoise/…`
|
||||||
|
|
||||||
|
Filenames are templated at the top of `wan22.py`; update those if the upstream repos rename files.
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
- `tests/component/test_02_wan22_loras.py` — full pipeline load + LoRA apply
|
||||||
|
- `tests/component/test_09_gguf_generate.py` — GGUF end-to-end I2V
|
||||||
|
- `tests/component/test_10_t5_encode.py` — T5 encoder dtype path
|
||||||
|
- `tests/component/test_11_image_encode.py` — image → VAE latent
|
||||||
|
- `tests/component/test_12_dit_single_step.py` — one DIT step per expert
|
||||||
|
- `tests/component/test_13_vae_decode.py` — VAE decode → RGB
|
||||||
|
|
||||||
|
When diagnosing a Blackwell regression, run 10 → 11 → 12 → 13 in that order; the failure localises to the first failing stage.
|
||||||
|
|
||||||
|
## LoRAs
|
||||||
|
|
||||||
|
`switch_lora(path, strength)` applies; `switch_lora("", 0.0)` removes. `load_loras`/`unload_loras` in this wrapper iterate over `LoRASpec`s from config and call `switch_lora` per `target` sub-model (`high_noise`, `low_noise`, or `both`). Wrong target = silently wrong output.
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""Thin wrappers around 3rd-party video generation models.
|
||||||
|
|
||||||
|
Each submodule isolates one external dependency so the real API surface
|
||||||
|
can be updated in a single file without touching the pipeline.
|
||||||
|
|
||||||
|
Submodules:
|
||||||
|
- ``wan22``: Wan2.2-Lightning image-to-video via LightX2V
|
||||||
|
- ``musetalk``: MuseTalk audio-driven lip-sync
|
||||||
|
- ``muxer``: ffmpeg-based frame/audio → MP4 encoding
|
||||||
|
"""
|
||||||
@@ -0,0 +1,151 @@
|
|||||||
|
"""MuseTalk audio-driven lip-sync wrapper.
|
||||||
|
|
||||||
|
MuseTalk takes a sequence of face frames + driving audio and returns a new
|
||||||
|
sequence of frames where the mouth region is animated to match the audio.
|
||||||
|
|
||||||
|
This module isolates MuseTalk's real API behind a single ``lip_sync()``
|
||||||
|
method. MuseTalk's upstream Python surface varies between forks — if the
|
||||||
|
real import path or call signature differs, update this file only.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MuseTalkEngine:
|
||||||
|
"""Thin wrapper over MuseTalk inference."""
|
||||||
|
|
||||||
|
def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
|
||||||
|
self.model_path = model_path
|
||||||
|
|
||||||
|
# MuseTalk's canonical entry point is ``musetalk.inference`` or a
|
||||||
|
# similar ``MuseTalkInfer`` class. Try the most common imports.
|
||||||
|
self._infer = self._load_impl(model_path)
|
||||||
|
log.info("MuseTalk engine loaded from %s", model_path)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_impl(model_path: str):
|
||||||
|
"""Load the MuseTalk inference implementation.
|
||||||
|
|
||||||
|
Upstream MuseTalk has no library-style entry point — it's a bundle
|
||||||
|
of training/inference CLI scripts. The bhetherman/MuseTalk fork at
|
||||||
|
``third_party/MuseTalk`` adds package metadata but the low-level
|
||||||
|
API is still the raw ``musetalk.utils.*`` and ``musetalk.models.*``
|
||||||
|
modules. We import them here to verify the install succeeded; the
|
||||||
|
actual pipeline (VAE, UNet, Whisper, face detection, blending)
|
||||||
|
is wired up inside ``MuseTalkEngine.lip_sync``.
|
||||||
|
"""
|
||||||
|
resolved = model_path
|
||||||
|
if not os.path.isdir(model_path) and "/" in model_path:
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
resolved = snapshot_download(repo_id=model_path)
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from musetalk.utils.utils import load_all_model # type: ignore[import-not-found] # noqa: F401
|
||||||
|
from musetalk.utils.audio_processor import AudioProcessor # type: ignore[import-not-found] # noqa: F401
|
||||||
|
except ImportError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"MuseTalk Python package is not importable. "
|
||||||
|
"Check that third_party/MuseTalk was installed via "
|
||||||
|
"`pip install /opt/MuseTalk` in the Dockerfile."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Return the resolved weight path; lip_sync loads models lazily on
|
||||||
|
# first call so import-time failures don't block voice-only startup.
|
||||||
|
return {"model_path": resolved, "loaded": False}
|
||||||
|
|
||||||
|
# --- Inference ---------------------------------------------------------
|
||||||
|
|
||||||
|
def lip_sync(
|
||||||
|
self,
|
||||||
|
frames: np.ndarray,
|
||||||
|
audio: np.ndarray,
|
||||||
|
sample_rate: int,
|
||||||
|
fps: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Return new frames with lip-sync applied to match ``audio``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frames: uint8 ``[T, H, W, 3]`` RGB base frames.
|
||||||
|
audio: float32 mono 1D audio.
|
||||||
|
sample_rate: sample rate of ``audio``.
|
||||||
|
fps: frame rate of ``frames``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
|
||||||
|
to match audio duration at ``fps``.
|
||||||
|
"""
|
||||||
|
if frames.ndim != 4 or frames.shape[-1] != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalise frame count to audio duration so the caller doesn't have
|
||||||
|
# to do the arithmetic.
|
||||||
|
target_t = int(round(len(audio) / sample_rate * fps))
|
||||||
|
if target_t > 0 and len(frames) != target_t:
|
||||||
|
frames = _fit_frames_to_length(frames, target_t)
|
||||||
|
|
||||||
|
# MuseTalk's real inference path (see third_party/MuseTalk/scripts/
|
||||||
|
# realtime_inference.py::Avatar.inference) needs:
|
||||||
|
# - mmpose + mmcv + mmengine (dwpose keypoint detection)
|
||||||
|
# - face_alignment (bbox)
|
||||||
|
# - MuseTalk UNet + VAE weights (TMElyralab/MuseTalk HF repo)
|
||||||
|
# - Whisper encoder (openai/whisper-tiny)
|
||||||
|
# - face_parsing weights
|
||||||
|
# Plus its preprocessing module has import-time side effects that
|
||||||
|
# load dwpose weights from a CWD-relative path. Turn the full
|
||||||
|
# pipeline on by extending this method once those deps are
|
||||||
|
# installed and weights are resolved — until then, callers should
|
||||||
|
# keep ``config.video.musetalk.enabled: false`` and VideoEngine
|
||||||
|
# will skip the lip-sync pass.
|
||||||
|
raise NotImplementedError(
|
||||||
|
"MuseTalk lip-sync pipeline is not wired up yet. "
|
||||||
|
"Set config.video.musetalk.enabled=false to bypass."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
|
||||||
|
"""Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.
|
||||||
|
|
||||||
|
Repeats with a ping-pong / boomerang tail so the seam between loops is
|
||||||
|
less jarring than a hard cut back to frame 0.
|
||||||
|
"""
|
||||||
|
if target_t <= 0:
|
||||||
|
return frames
|
||||||
|
t = len(frames)
|
||||||
|
if t == target_t:
|
||||||
|
return frames
|
||||||
|
if t > target_t:
|
||||||
|
return frames[:target_t]
|
||||||
|
# Extend via ping-pong looping
|
||||||
|
extended = [frames]
|
||||||
|
total = t
|
||||||
|
flip = True
|
||||||
|
while total < target_t:
|
||||||
|
seg = frames[::-1] if flip else frames
|
||||||
|
extended.append(seg)
|
||||||
|
total += t
|
||||||
|
flip = not flip
|
||||||
|
return np.concatenate(extended, axis=0)[:target_t]
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_uint8_rgb(arr) -> np.ndarray:
|
||||||
|
"""Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB."""
|
||||||
|
result = np.asarray(arr)
|
||||||
|
if result.dtype != np.uint8:
|
||||||
|
if result.dtype in (np.float32, np.float64):
|
||||||
|
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
|
||||||
|
else:
|
||||||
|
result = result.astype(np.uint8)
|
||||||
|
if result.ndim == 3:
|
||||||
|
result = result[None, ...]
|
||||||
|
return result
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
"""ffmpeg-based frame + audio → MP4 muxing.
|
||||||
|
|
||||||
|
Uses the system ``ffmpeg`` binary already installed in the Dockerfile.
|
||||||
|
No extra python dependencies beyond ``numpy``.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _ffmpeg_bin() -> str:
|
||||||
|
bin_path = shutil.which("ffmpeg")
|
||||||
|
if bin_path is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"ffmpeg binary not found on PATH. It should be installed by "
|
||||||
|
"the Dockerfile (line 13). Ensure you're running inside the "
|
||||||
|
"docker image or install ffmpeg locally."
|
||||||
|
)
|
||||||
|
return bin_path
|
||||||
|
|
||||||
|
|
||||||
|
def _write_raw_frames(frames: np.ndarray, path: str) -> tuple[int, int]:
|
||||||
|
"""Write uint8 RGB frames to ``path`` as raw rgb24 bytes. Returns (h, w)."""
|
||||||
|
if frames.ndim != 4 or frames.shape[-1] != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
|
||||||
|
)
|
||||||
|
if frames.dtype != np.uint8:
|
||||||
|
frames = frames.astype(np.uint8)
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(frames.tobytes())
|
||||||
|
_, h, w, _ = frames.shape
|
||||||
|
return h, w
|
||||||
|
|
||||||
|
|
||||||
|
def _write_wav(audio: np.ndarray, sample_rate: int, path: str) -> None:
|
||||||
|
"""Write a float32 mono audio array to a 16-bit PCM WAV at ``path``."""
|
||||||
|
from scipy.io import wavfile # type: ignore[import-not-found]
|
||||||
|
audio = np.asarray(audio, dtype=np.float32).reshape(-1)
|
||||||
|
int16 = np.clip(audio * 32767.0, -32768, 32767).astype(np.int16)
|
||||||
|
wavfile.write(path, sample_rate, int16)
|
||||||
|
|
||||||
|
|
||||||
|
def frames_to_mp4_loop(frames: np.ndarray, fps: int) -> bytes:
|
||||||
|
"""Encode ``frames`` to a silent MP4 suitable for looping playback.
|
||||||
|
|
||||||
|
Used for the idle clip: no audio track, loopable on an HTMLMediaElement
|
||||||
|
without audible seams.
|
||||||
|
"""
|
||||||
|
if frames.size == 0:
|
||||||
|
raise ValueError("frames_to_mp4_loop: empty frames")
|
||||||
|
|
||||||
|
ffmpeg = _ffmpeg_bin()
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
raw_path = os.path.join(td, "frames.raw")
|
||||||
|
out_path = os.path.join(td, "out.mp4")
|
||||||
|
h, w = _write_raw_frames(frames, raw_path)
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
ffmpeg, "-y",
|
||||||
|
"-f", "rawvideo",
|
||||||
|
"-pix_fmt", "rgb24",
|
||||||
|
"-s", f"{w}x{h}",
|
||||||
|
"-r", str(fps),
|
||||||
|
"-i", raw_path,
|
||||||
|
"-an",
|
||||||
|
"-c:v", "libx264",
|
||||||
|
"-preset", "veryfast",
|
||||||
|
"-pix_fmt", "yuv420p",
|
||||||
|
"-movflags", "+faststart",
|
||||||
|
out_path,
|
||||||
|
]
|
||||||
|
log.debug("muxer idle clip: %s", " ".join(cmd))
|
||||||
|
_run_ffmpeg(cmd)
|
||||||
|
with open(out_path, "rb") as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def frames_and_audio_to_mp4(
|
||||||
|
frames: np.ndarray,
|
||||||
|
audio: np.ndarray,
|
||||||
|
sample_rate: int,
|
||||||
|
fps: int,
|
||||||
|
) -> bytes:
|
||||||
|
"""Encode ``frames`` + ``audio`` to an MP4 with H.264 video + AAC audio.
|
||||||
|
|
||||||
|
Used for per-turn speaking clips.
|
||||||
|
"""
|
||||||
|
if frames.size == 0:
|
||||||
|
raise ValueError("frames_and_audio_to_mp4: empty frames")
|
||||||
|
if audio.size == 0:
|
||||||
|
raise ValueError("frames_and_audio_to_mp4: empty audio")
|
||||||
|
|
||||||
|
ffmpeg = _ffmpeg_bin()
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
raw_path = os.path.join(td, "frames.raw")
|
||||||
|
wav_path = os.path.join(td, "audio.wav")
|
||||||
|
out_path = os.path.join(td, "out.mp4")
|
||||||
|
|
||||||
|
h, w = _write_raw_frames(frames, raw_path)
|
||||||
|
_write_wav(audio, sample_rate, wav_path)
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
ffmpeg, "-y",
|
||||||
|
"-f", "rawvideo",
|
||||||
|
"-pix_fmt", "rgb24",
|
||||||
|
"-s", f"{w}x{h}",
|
||||||
|
"-r", str(fps),
|
||||||
|
"-i", raw_path,
|
||||||
|
"-i", wav_path,
|
||||||
|
"-c:v", "libx264",
|
||||||
|
"-preset", "veryfast",
|
||||||
|
"-pix_fmt", "yuv420p",
|
||||||
|
"-c:a", "aac",
|
||||||
|
"-b:a", "128k",
|
||||||
|
"-shortest",
|
||||||
|
"-movflags", "+faststart",
|
||||||
|
out_path,
|
||||||
|
]
|
||||||
|
log.debug("muxer speaking clip: %s", " ".join(cmd))
|
||||||
|
_run_ffmpeg(cmd)
|
||||||
|
with open(out_path, "rb") as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_ffmpeg(cmd: list[str]) -> None:
|
||||||
|
try:
|
||||||
|
proc = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
log.error("ffmpeg failed (exit %d): %s", e.returncode, e.stderr.decode(errors="replace"))
|
||||||
|
raise
|
||||||
|
if proc.returncode != 0: # pragma: no cover
|
||||||
|
raise RuntimeError(f"ffmpeg returned {proc.returncode}")
|
||||||
@@ -0,0 +1,643 @@
|
|||||||
|
"""Wan2.2-TI2V-5B-Turbo (dense) image-to-video wrapper via LightX2V.
|
||||||
|
|
||||||
|
This wrapper targets LightX2V's actual Python entry points (verified against
|
||||||
|
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
|
||||||
|
|
||||||
|
from lightx2v.utils.set_config import set_config
|
||||||
|
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
|
||||||
|
from lightx2v.infer import init_runner
|
||||||
|
|
||||||
|
args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...)
|
||||||
|
config = set_config(args)
|
||||||
|
input_info = init_empty_input_info(args.task, args.support_tasks)
|
||||||
|
runner = init_runner(config) # loads all weights — done ONCE
|
||||||
|
|
||||||
|
update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...})
|
||||||
|
runner.run_pipeline(input_info) # per-turn; MP4 written to save_result_path
|
||||||
|
# LoRA hot-swap:
|
||||||
|
runner.switch_lora(lora_path, strength) # swap in
|
||||||
|
runner.switch_lora("", 0.0) # remove
|
||||||
|
|
||||||
|
Model weights are loaded once at construction and held resident across turns
|
||||||
|
so reflective mode doesn't re-pay the load cost each reply.
|
||||||
|
|
||||||
|
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
|
||||||
|
- Wan-AI/Wan2.2-TI2V-5B — T5 encoder, VAE, tokenizer/config only.
|
||||||
|
The bf16 DIT shards are SKIPPED via
|
||||||
|
ignore_patterns — replaced by the GGUF
|
||||||
|
checkpoint from dit_repo.
|
||||||
|
- dit_repo (configurable) — single dense GGUF DIT checkpoint, e.g.
|
||||||
|
hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import tempfile
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from server.video import LoRASpec
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# --- GGUF filename for the dense 5B Turbo repo ------------------------------
|
||||||
|
# hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF ships flat: Wan2_2-TI2V-5B-Turbo-{quant}.gguf
|
||||||
|
GGUF_TURBO_5B_FILE = "Wan2_2-TI2V-5B-Turbo-{quant}.gguf"
|
||||||
|
|
||||||
|
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
|
||||||
|
T5_FP8_REPO = "lightx2v/Encoders"
|
||||||
|
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
|
||||||
|
|
||||||
|
# The Wan-AI base repo ships bf16 DIT weight shards alongside the T5/VAE/
|
||||||
|
# tokenizer support files. We only need the latter — the GGUF from dit_repo
|
||||||
|
# replaces the DIT weights entirely. Keep config.json / tokenizer files.
|
||||||
|
BASE_REPO_IGNORE_PATTERNS = [
|
||||||
|
"*.pt",
|
||||||
|
"diffusion_pytorch_model*.safetensors",
|
||||||
|
"assets/*",
|
||||||
|
"examples/*",
|
||||||
|
"nohup.out",
|
||||||
|
"*.md",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
|
||||||
|
"""Recursively find fp32 tensors reachable from ``obj`` and cast to fp16.
|
||||||
|
|
||||||
|
The dense ``wan2.2`` DIT isn't a standard ``nn.Module`` — some fp32
|
||||||
|
tensors (conv3d bias etc.) live outside ``pre_weight``/``post_weight``
|
||||||
|
and are missed by the structured sweep. This generic traversal catches
|
||||||
|
them. Bounded depth + visited-set to avoid cycles.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
if visited is None:
|
||||||
|
visited = set()
|
||||||
|
obj_id = id(obj)
|
||||||
|
if obj_id in visited or depth > 6:
|
||||||
|
return 0
|
||||||
|
visited.add(obj_id)
|
||||||
|
n = 0
|
||||||
|
for attr_name in dir(obj):
|
||||||
|
if attr_name.startswith("__"):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
val = getattr(obj, attr_name)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
|
||||||
|
try:
|
||||||
|
setattr(obj, attr_name, val.to(torch.float16))
|
||||||
|
n += 1
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif hasattr(val, "__dict__") and not callable(val):
|
||||||
|
n += _cast_all_fp32_tensors(val, visited, depth + 1)
|
||||||
|
return n
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||||
|
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
|
||||||
|
|
||||||
|
sgl_kernel's CUTLASS-based fp8 GEMM doesn't ship SM120 kernels yet.
|
||||||
|
PyTorch 2.8+'s native ``_scaled_mm`` works on all architectures
|
||||||
|
including Blackwell. This patch is idempotent.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import sgl_kernel # type: ignore[import-not-found]
|
||||||
|
except ImportError:
|
||||||
|
return # no sgl_kernel → fp8 T5 not in use
|
||||||
|
|
||||||
|
if getattr(sgl_kernel, "_fp8_patched_for_blackwell", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return
|
||||||
|
|
||||||
|
cap = torch.cuda.get_device_capability()
|
||||||
|
if cap[0] < 12:
|
||||||
|
return # only patch on Blackwell+
|
||||||
|
|
||||||
|
_orig = sgl_kernel.fp8_scaled_mm
|
||||||
|
|
||||||
|
def _torch_fp8_scaled_mm(
|
||||||
|
a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
a_scale: torch.Tensor,
|
||||||
|
b_scale: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# torch._scaled_mm expects (M,K) @ (N,K).t() with:
|
||||||
|
# scale_a: scalar or (M,1)
|
||||||
|
# scale_b: scalar or (1,N)
|
||||||
|
# sgl_kernel provides scale_b as (N,1) — transpose it.
|
||||||
|
if b_scale.dim() == 2 and b_scale.shape[1] == 1:
|
||||||
|
b_scale = b_scale.t()
|
||||||
|
# _scaled_mm requires B to be column-major (stride(0)==1).
|
||||||
|
bt = b.t().contiguous().t()
|
||||||
|
out = torch._scaled_mm(
|
||||||
|
a, bt,
|
||||||
|
scale_a=a_scale, scale_b=b_scale,
|
||||||
|
out_dtype=out_dtype, bias=bias,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
sgl_kernel.fp8_scaled_mm = _torch_fp8_scaled_mm
|
||||||
|
sgl_kernel._fp8_patched_for_blackwell = True
|
||||||
|
log.info("Patched sgl_kernel.fp8_scaled_mm → torch._scaled_mm for Blackwell (SM%d%d).", *cap)
|
||||||
|
|
||||||
|
|
||||||
|
class Wan22Pipeline:
|
||||||
|
"""Wrapper around LightX2V's dense Wan2.2-TI2V-5B-Turbo runner.
|
||||||
|
|
||||||
|
The 5B Turbo repo ships a single dense DIT checkpoint (not MoE) as GGUF.
|
||||||
|
``dit_quant_scheme`` must be a GGUF variant (``gguf-Q8_0`` default,
|
||||||
|
``gguf-Q4_K_M`` for lower VRAM); no fp8 5B Turbo weights exist.
|
||||||
|
|
||||||
|
Constructor downloads (if needed) both HF repos, writes a runtime JSON
|
||||||
|
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
|
||||||
|
``generate_i2v`` runs one inference turn against the already-loaded runner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_repo: str,
|
||||||
|
dit_repo: str,
|
||||||
|
config_json: str,
|
||||||
|
model_cls: str = "wan2.2",
|
||||||
|
resolution: int = 480,
|
||||||
|
fps: int = 16,
|
||||||
|
dit_quant_scheme: str = "gguf-Q8_0",
|
||||||
|
t5_quantized: bool = True,
|
||||||
|
):
|
||||||
|
self.base_repo = base_repo
|
||||||
|
self.dit_repo = dit_repo
|
||||||
|
self.config_json_template = config_json
|
||||||
|
self.model_cls = model_cls
|
||||||
|
self.resolution = resolution
|
||||||
|
self.fps = fps
|
||||||
|
self.dit_quant_scheme = dit_quant_scheme
|
||||||
|
self.t5_quantized = t5_quantized
|
||||||
|
self._applied_loras: list[LoRASpec] = []
|
||||||
|
|
||||||
|
self._is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||||
|
if not self._is_gguf:
|
||||||
|
raise ValueError(
|
||||||
|
f"dit_quant_scheme must be a GGUF variant for dense 5B Turbo "
|
||||||
|
f"(got {dit_quant_scheme!r}); no fp8 5B Turbo weights exist."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpt.
|
||||||
|
self._model_root = self._ensure_base_repo(base_repo)
|
||||||
|
self._dit_ckpt = self._ensure_dit_checkpoint(
|
||||||
|
dit_repo, dit_quant_scheme,
|
||||||
|
)
|
||||||
|
self._t5_fp8_ckpt = (
|
||||||
|
self._ensure_t5_fp8() if t5_quantized else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Materialize a runtime JSON config with absolute ckpt paths.
|
||||||
|
self._runtime_json_path = self._build_runtime_config()
|
||||||
|
|
||||||
|
# 3. Build the argparse-like namespace LightX2V.set_config() expects.
|
||||||
|
args = self._build_args(
|
||||||
|
model_cls=model_cls,
|
||||||
|
model_path=self._model_root,
|
||||||
|
config_json=self._runtime_json_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Import LightX2V (scoped here so ``import server.video_models.wan22``
|
||||||
|
# never pulls in lightx2v — tests can import this module on CPU).
|
||||||
|
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
|
||||||
|
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
|
||||||
|
from lightx2v.infer import init_runner # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
_patch_fp8_scaled_mm_for_blackwell()
|
||||||
|
|
||||||
|
# 5. Load all models under default DTYPE=BF16 so T5 (which is
|
||||||
|
# hardcoded to bf16 weights) initialises its offload buffers
|
||||||
|
# correctly. We flip to FP16 *after* init_runner completes.
|
||||||
|
log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
|
||||||
|
model_cls, self._model_root)
|
||||||
|
self._config = set_config(args)
|
||||||
|
|
||||||
|
self._input_info_template = init_empty_input_info(
|
||||||
|
args.task, args.support_tasks
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info("LightX2V init_runner — loading weights (this takes a while)...")
|
||||||
|
self._runner = init_runner(self._config)
|
||||||
|
log.info("LightX2V runner loaded; weights resident.")
|
||||||
|
|
||||||
|
# 6. GGUF: switch global DTYPE to FP16 for inference. GGUF DIT
|
||||||
|
# dequantises to fp16, and many intermediate tensors inside the
|
||||||
|
# DIT forward pass are allocated via GET_DTYPE(). The T5 encoder
|
||||||
|
# is wrapped to temporarily restore BF16 during its forward.
|
||||||
|
if self._is_gguf:
|
||||||
|
os.environ["DTYPE"] = "FP16"
|
||||||
|
from lightx2v.utils.envs import GET_DTYPE # type: ignore[import-not-found]
|
||||||
|
GET_DTYPE.cache_clear()
|
||||||
|
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
|
||||||
|
self._patch_t5_dtype_for_gguf()
|
||||||
|
self._patch_vae_dtype_for_gguf()
|
||||||
|
self._patch_dit_fp32_weights_for_gguf()
|
||||||
|
|
||||||
|
# --- GGUF dtype compatibility patch ----------------------------------------
|
||||||
|
|
||||||
|
def _patch_t5_dtype_for_gguf(self) -> None:
|
||||||
|
"""Wrap the T5 encoder so it temporarily restores DTYPE=BF16.
|
||||||
|
|
||||||
|
The T5 encoder is hardcoded to bfloat16 weights (wan_runner.py). When
|
||||||
|
the global DTYPE is FP16 (required for GGUF DIT), the T5's CPU-offload
|
||||||
|
path breaks because intermediate tensor dtypes no longer match the bf16
|
||||||
|
weights. We wrap ``run_text_encoder`` to temporarily flip GET_DTYPE()
|
||||||
|
back to bf16, then restore fp16 before the DIT runs.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import types
|
||||||
|
from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
runner = self._runner
|
||||||
|
orig_run_text_encoder = runner.run_text_encoder.__func__
|
||||||
|
|
||||||
|
def bf16_text_encoder(self_runner, *args, **kwargs):
|
||||||
|
import torch
|
||||||
|
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
|
||||||
|
os.environ["DTYPE"] = "BF16"
|
||||||
|
GET_DTYPE.cache_clear()
|
||||||
|
GET_SENSITIVE_DTYPE.cache_clear()
|
||||||
|
try:
|
||||||
|
result = orig_run_text_encoder(self_runner, *args, **kwargs)
|
||||||
|
finally:
|
||||||
|
# Restore FP16 for the DIT / rest of the pipeline.
|
||||||
|
os.environ["DTYPE"] = "FP16"
|
||||||
|
GET_DTYPE.cache_clear()
|
||||||
|
GET_SENSITIVE_DTYPE.cache_clear()
|
||||||
|
# Cast bf16 T5 outputs to fp16 so they match the GGUF DIT dtype.
|
||||||
|
def _to_fp16(x):
|
||||||
|
if isinstance(x, torch.Tensor) and x.dtype == torch.bfloat16:
|
||||||
|
return x.to(torch.float16)
|
||||||
|
if isinstance(x, list):
|
||||||
|
return [_to_fp16(v) for v in x]
|
||||||
|
if isinstance(x, tuple):
|
||||||
|
return tuple(_to_fp16(v) for v in x)
|
||||||
|
if isinstance(x, dict):
|
||||||
|
return {k: _to_fp16(v) for k, v in x.items()}
|
||||||
|
return x
|
||||||
|
return _to_fp16(result)
|
||||||
|
|
||||||
|
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
|
||||||
|
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
|
||||||
|
|
||||||
|
def _patch_vae_dtype_for_gguf(self) -> None:
|
||||||
|
"""Cast VAE encoder/decoder weights to fp16 to match GGUF DIT dtype.
|
||||||
|
|
||||||
|
The VAE weights load as bf16 (the default). Under GGUF the DIT runs in
|
||||||
|
fp16 and the runner casts VAE inputs via ``.to(GET_DTYPE())`` — which
|
||||||
|
under DTYPE=FP16 collides with bf16 VAE weights in Conv3d. Since the
|
||||||
|
VAE is a plain float model (not quantized), simply converting its
|
||||||
|
weights to fp16 avoids both input-vs-weight mismatches and the need
|
||||||
|
for any runtime dtype juggling.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
runner = self._runner
|
||||||
|
for name in ("vae_encoder", "vae_decoder"):
|
||||||
|
mod = getattr(runner, name, None)
|
||||||
|
if mod is None:
|
||||||
|
continue
|
||||||
|
inner = getattr(mod, "model", mod)
|
||||||
|
if hasattr(inner, "to"):
|
||||||
|
inner.to(dtype=torch.float16)
|
||||||
|
# The outer WanVAE wrapper also holds mean/inv_std/scale tensors
|
||||||
|
# used by encode/decode (z = z/inv_std + mean). Cast them too, or
|
||||||
|
# the first op upcasts fp16 latents back to fp32/bf16.
|
||||||
|
for attr in ("mean", "inv_std"):
|
||||||
|
t = getattr(mod, attr, None)
|
||||||
|
if isinstance(t, torch.Tensor):
|
||||||
|
setattr(mod, attr, t.to(torch.float16))
|
||||||
|
scale = getattr(mod, "scale", None)
|
||||||
|
if isinstance(scale, list):
|
||||||
|
mod.scale = [
|
||||||
|
t.to(torch.float16) if isinstance(t, torch.Tensor) else t
|
||||||
|
for t in scale
|
||||||
|
]
|
||||||
|
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
|
||||||
|
|
||||||
|
def _patch_dit_fp32_weights_for_gguf(self) -> None:
|
||||||
|
"""Cast leftover fp32 DIT weights to fp16 (dense model).
|
||||||
|
|
||||||
|
GGUF dequantises the transformer blocks to fp16, but a handful of
|
||||||
|
non-quantised weights (notably ``patch_embedding.pin_weight``) end up
|
||||||
|
loaded as fp32. That breaks the first conv in the DIT forward pass
|
||||||
|
(fp16 input vs fp32 weight). Dense ``wan2.2`` exposes the model
|
||||||
|
directly at ``runner.model`` (no MoE wrapper). After the structured
|
||||||
|
pre/post weight sweep, we also run a recursive traversal to catch
|
||||||
|
fp32 conv3d biases etc. that live outside pre/post_weight.
|
||||||
|
"""
|
||||||
|
runner = self._runner
|
||||||
|
n_struct = self._cast_fp32_dit_weights_in_model(runner.model)
|
||||||
|
n_extra = _cast_all_fp32_tensors(runner.model)
|
||||||
|
log.info(
|
||||||
|
"Cast %d (structured) + %d (recursive) fp32 DIT tensors to fp16 for GGUF pipeline.",
|
||||||
|
n_struct, n_extra,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cast_fp32_dit_weights_in_model(m) -> int:
|
||||||
|
import torch
|
||||||
|
n_cast = 0
|
||||||
|
for weights_attr in ("pre_weight", "post_weight"):
|
||||||
|
w = getattr(m, weights_attr, None)
|
||||||
|
if w is None:
|
||||||
|
continue
|
||||||
|
for sub_name in dir(w):
|
||||||
|
if sub_name.startswith("_"):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
sub = getattr(w, sub_name)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if sub is None:
|
||||||
|
continue
|
||||||
|
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
|
||||||
|
t = getattr(sub, t_name, None)
|
||||||
|
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
|
||||||
|
casted = t.to(torch.float16)
|
||||||
|
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
|
||||||
|
try:
|
||||||
|
casted = casted.pin_memory()
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
setattr(sub, t_name, casted)
|
||||||
|
n_cast += 1
|
||||||
|
return n_cast
|
||||||
|
|
||||||
|
# --- Weight provisioning -------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_base_repo(base_repo: str) -> str:
|
||||||
|
"""Return a local directory containing the Wan2.2 base support files.
|
||||||
|
|
||||||
|
If ``base_repo`` is already a local directory, use it as-is. Otherwise
|
||||||
|
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
|
||||||
|
shards (they're replaced by the quantised files).
|
||||||
|
"""
|
||||||
|
if os.path.isdir(base_repo):
|
||||||
|
return base_repo
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
log.info("Downloading Wan2.2 base support files from %s "
|
||||||
|
"(skipping bf16 DIT shards)...", base_repo)
|
||||||
|
return snapshot_download(
|
||||||
|
repo_id=base_repo,
|
||||||
|
ignore_patterns=BASE_REPO_IGNORE_PATTERNS,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_dit_checkpoint(
|
||||||
|
dit_repo: str,
|
||||||
|
dit_quant_scheme: str,
|
||||||
|
) -> str:
|
||||||
|
"""Return the local path to the single dense GGUF DIT checkpoint."""
|
||||||
|
if not dit_repo:
|
||||||
|
raise ValueError("dit_repo must be a HF repo id or local directory.")
|
||||||
|
if not dit_quant_scheme.startswith("gguf-"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Only GGUF quant schemes are supported for dense 5B Turbo "
|
||||||
|
f"(got {dit_quant_scheme!r})."
|
||||||
|
)
|
||||||
|
|
||||||
|
quant = dit_quant_scheme.replace("gguf-", "")
|
||||||
|
filename = GGUF_TURBO_5B_FILE.format(quant=quant)
|
||||||
|
|
||||||
|
if os.path.isdir(dit_repo):
|
||||||
|
path = os.path.join(dit_repo, filename)
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"DIT checkpoint not found in {dit_repo}: expected {filename}"
|
||||||
|
)
|
||||||
|
return path
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
log.info("Downloading %s DIT checkpoint from %s ...",
|
||||||
|
dit_quant_scheme, dit_repo)
|
||||||
|
return hf_hub_download(repo_id=dit_repo, filename=filename)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_t5_fp8() -> str:
|
||||||
|
"""Download the fp8 T5 encoder from lightx2v/Encoders (if not cached).
|
||||||
|
|
||||||
|
Returns the local path to the safetensors file (~6 GB).
|
||||||
|
"""
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
log.info("Downloading fp8 T5 encoder from %s ...", T5_FP8_REPO)
|
||||||
|
return hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
|
||||||
|
|
||||||
|
def _build_runtime_config(self) -> str:
|
||||||
|
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
|
||||||
|
with open(self.config_json_template, "r", encoding="utf-8") as f:
|
||||||
|
cfg = json.load(f)
|
||||||
|
# Drop editorial comments before passing to LightX2V.
|
||||||
|
cfg.pop("_comment", None)
|
||||||
|
cfg["dit_quantized_ckpt"] = self._dit_ckpt
|
||||||
|
cfg.setdefault("fps", self.fps)
|
||||||
|
|
||||||
|
# T5 fp8 quantization.
|
||||||
|
if self._t5_fp8_ckpt:
|
||||||
|
cfg["t5_quantized"] = True
|
||||||
|
cfg["t5_quant_scheme"] = "fp8-sgl"
|
||||||
|
cfg["t5_quantized_ckpt"] = self._t5_fp8_ckpt
|
||||||
|
|
||||||
|
tmp = tempfile.NamedTemporaryFile(
|
||||||
|
prefix="wan22_dit_", suffix=".json",
|
||||||
|
mode="w", delete=False, encoding="utf-8",
|
||||||
|
)
|
||||||
|
json.dump(cfg, tmp, indent=2)
|
||||||
|
tmp.close()
|
||||||
|
log.info("Runtime LightX2V config: %s", tmp.name)
|
||||||
|
return tmp.name
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_args(
|
||||||
|
*, model_cls: str, model_path: str, config_json: str
|
||||||
|
) -> argparse.Namespace:
|
||||||
|
"""Mirror every field from ``lightx2v.infer.main``'s argparse so
|
||||||
|
``set_config`` finds the attributes it expects. We only customize the
|
||||||
|
model/task/path fields; everything else stays at the CLI defaults.
|
||||||
|
"""
|
||||||
|
return argparse.Namespace(
|
||||||
|
seed=42,
|
||||||
|
model_cls=model_cls,
|
||||||
|
task="i2v",
|
||||||
|
support_tasks=[],
|
||||||
|
model_path=model_path,
|
||||||
|
sf_model_path=None,
|
||||||
|
config_json=config_json,
|
||||||
|
use_prompt_enhancer=False,
|
||||||
|
prompt="",
|
||||||
|
negative_prompt="",
|
||||||
|
image_path="",
|
||||||
|
last_frame_path="",
|
||||||
|
audio_path="",
|
||||||
|
image_strength="1.0",
|
||||||
|
image_frame_idx="",
|
||||||
|
src_ref_images=None,
|
||||||
|
src_video=None,
|
||||||
|
src_mask=None,
|
||||||
|
src_pose_path=None,
|
||||||
|
src_face_path=None,
|
||||||
|
src_bg_path=None,
|
||||||
|
src_mask_path=None,
|
||||||
|
pose=None,
|
||||||
|
action_path=None,
|
||||||
|
action_ckpt=None,
|
||||||
|
save_result_path=None,
|
||||||
|
return_result_tensor=False,
|
||||||
|
target_shape=[],
|
||||||
|
target_video_length=81,
|
||||||
|
aspect_ratio="",
|
||||||
|
video_path=None,
|
||||||
|
sr_ratio=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- LoRA --------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_loras(self, specs: list["LoRASpec"]) -> None:
|
||||||
|
"""Apply LoRAs to the dense Wan2.2-TI2V-5B pipeline.
|
||||||
|
|
||||||
|
Dense has a single DIT (no MoE experts), so ``target`` must be
|
||||||
|
``"both"``. GGUF DIT weights don't expose a ``lora_down`` buffer,
|
||||||
|
so ``switch_lora`` would crash — we use the dynamic-apply path that
|
||||||
|
merges LoRAs during GGUF dequant.
|
||||||
|
"""
|
||||||
|
if not specs:
|
||||||
|
return
|
||||||
|
|
||||||
|
resolved: list[tuple["LoRASpec", str]] = []
|
||||||
|
for spec in specs:
|
||||||
|
if spec.target != "both":
|
||||||
|
raise ValueError(
|
||||||
|
f"Dense 5B Turbo has a single DIT; LoRA target must be "
|
||||||
|
f"'both' (got {spec.target!r})."
|
||||||
|
)
|
||||||
|
local_path = self._resolve_lora_path(spec.path)
|
||||||
|
log.info(" LoRA %s → strength=%.2f (%s)",
|
||||||
|
spec.name or spec.path, spec.weight, local_path)
|
||||||
|
resolved.append((spec, local_path))
|
||||||
|
|
||||||
|
lora_cfgs = [
|
||||||
|
{"path": local_path, "strength": spec.weight}
|
||||||
|
for spec, local_path in resolved
|
||||||
|
]
|
||||||
|
self._runner.set_config({
|
||||||
|
"lora_configs": lora_cfgs,
|
||||||
|
"lora_dynamic_apply": True,
|
||||||
|
})
|
||||||
|
self._applied_loras = list(specs)
|
||||||
|
|
||||||
|
def unload_loras(self) -> None:
|
||||||
|
"""Remove all currently applied LoRAs."""
|
||||||
|
if not self._applied_loras:
|
||||||
|
return
|
||||||
|
self._runner.set_config({
|
||||||
|
"lora_configs": None,
|
||||||
|
"lora_dynamic_apply": False,
|
||||||
|
})
|
||||||
|
self._applied_loras = []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_lora_path(path: str) -> str:
|
||||||
|
"""Resolve a LoRA path. Supports:
|
||||||
|
- Absolute/relative local paths (returned as-is if the file exists)
|
||||||
|
- ``repo_id:filename`` HuggingFace references
|
||||||
|
"""
|
||||||
|
if os.path.isfile(path):
|
||||||
|
return path
|
||||||
|
if ":" in path and not path.startswith(("/", "./")):
|
||||||
|
repo_id, filename = path.split(":", 1)
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
return hf_hub_download(repo_id=repo_id, filename=filename)
|
||||||
|
return path
|
||||||
|
|
||||||
|
# --- Inference ---------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_i2v(
|
||||||
|
self,
|
||||||
|
image_path: str,
|
||||||
|
prompt: str,
|
||||||
|
seconds: int,
|
||||||
|
seed: int | None = None,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Run image-to-video inference and return decoded frames.
|
||||||
|
|
||||||
|
Returns ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB.
|
||||||
|
"""
|
||||||
|
if seed is None:
|
||||||
|
seed = random.randint(0, 2**31 - 1)
|
||||||
|
|
||||||
|
# Wan2.2 target_video_length is "frames including the conditioning
|
||||||
|
# frame", so N seconds → N*fps + 1.
|
||||||
|
target_frames = seconds * self.fps + 1
|
||||||
|
|
||||||
|
from lightx2v.utils.input_info import update_input_info_from_dict # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
|
||||||
|
out_path = tf.name
|
||||||
|
try:
|
||||||
|
log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d → %s",
|
||||||
|
prompt[:80], seconds, seed, out_path)
|
||||||
|
update_input_info_from_dict(
|
||||||
|
self._input_info_template,
|
||||||
|
{
|
||||||
|
"seed": seed,
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"image_path": image_path,
|
||||||
|
"save_result_path": out_path,
|
||||||
|
"target_video_length": target_frames,
|
||||||
|
"return_result_tensor": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self._runner.run_pipeline(self._input_info_template)
|
||||||
|
return _read_mp4_to_frames(out_path)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
os.remove(out_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# --- MP4 decoding helper ------------------------------------------------------
|
||||||
|
|
||||||
|
def _read_mp4_to_frames(path: str) -> np.ndarray:
|
||||||
|
"""Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``."""
|
||||||
|
try:
|
||||||
|
import imageio.v3 as iio # type: ignore[import-not-found]
|
||||||
|
frames = iio.imread(path, plugin="pyav")
|
||||||
|
arr = np.asarray(frames)
|
||||||
|
if arr.ndim == 3:
|
||||||
|
arr = arr[None, ...]
|
||||||
|
return arr.astype(np.uint8)
|
||||||
|
except Exception as e: # pragma: no cover - fallback path
|
||||||
|
log.warning("imageio decode failed (%s); falling back to cv2", e)
|
||||||
|
import cv2 # type: ignore[import-not-found]
|
||||||
|
cap = cv2.VideoCapture(path)
|
||||||
|
frames: list[np.ndarray] = []
|
||||||
|
while True:
|
||||||
|
ok, frame = cap.read()
|
||||||
|
if not ok:
|
||||||
|
break
|
||||||
|
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||||
|
cap.release()
|
||||||
|
if not frames:
|
||||||
|
raise RuntimeError(f"Failed to decode any frames from {path}")
|
||||||
|
return np.stack(frames, axis=0).astype(np.uint8)
|
||||||
+159
@@ -18,9 +18,18 @@ let pendingTextChunks = []; // [{chunkId, text}] - text waiting for its audio to
|
|||||||
let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback
|
let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback
|
||||||
let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user
|
let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user
|
||||||
|
|
||||||
|
// --- Video mode state ---
|
||||||
|
let videoModeEnabled = false; // true when server has video engine active AND ready
|
||||||
|
let videoModeName = "off"; // "off" | "library" | "reflective"
|
||||||
|
let idleClipUrl = null; // URL string (server-served) or null
|
||||||
|
let pendingSpeakingClipMeta = null; // {chunk_id, duration_ms, text} waiting for MP4 binary
|
||||||
|
let currentSpeakingClipBlobUrl = null;
|
||||||
|
|
||||||
const chatArea = document.getElementById("chat-area");
|
const chatArea = document.getElementById("chat-area");
|
||||||
const statusBadge = document.getElementById("status-badge");
|
const statusBadge = document.getElementById("status-badge");
|
||||||
const micBtn = document.getElementById("mic-btn");
|
const micBtn = document.getElementById("mic-btn");
|
||||||
|
const avatarVideo = document.getElementById("avatar-video");
|
||||||
|
const stageEl = document.getElementById("stage");
|
||||||
|
|
||||||
// --- WebSocket ---
|
// --- WebSocket ---
|
||||||
|
|
||||||
@@ -44,7 +53,18 @@ function connectWS() {
|
|||||||
|
|
||||||
ws.onmessage = (event) => {
|
ws.onmessage = (event) => {
|
||||||
if (event.data instanceof ArrayBuffer) {
|
if (event.data instanceof ArrayBuffer) {
|
||||||
|
// In video mode, the next binary frame after a "speaking_clip"
|
||||||
|
// envelope is an MP4 blob; otherwise it's a PCM audio chunk.
|
||||||
|
if (pendingSpeakingClipMeta) {
|
||||||
|
const meta = pendingSpeakingClipMeta;
|
||||||
|
pendingSpeakingClipMeta = null;
|
||||||
|
playSpeakingClip(event.data, meta);
|
||||||
|
} else if (videoModeEnabled) {
|
||||||
|
// Video mode is active but we didn't get a speaking_clip envelope
|
||||||
|
// first — ignore raw PCM so we don't double-play audio.
|
||||||
|
} else {
|
||||||
playAudioChunk(event.data);
|
playAudioChunk(event.data);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
handleJSON(JSON.parse(event.data));
|
handleJSON(JSON.parse(event.data));
|
||||||
}
|
}
|
||||||
@@ -59,6 +79,7 @@ function handleJSON(msg) {
|
|||||||
|
|
||||||
case "interrupt":
|
case "interrupt":
|
||||||
stopPlayback();
|
stopPlayback();
|
||||||
|
stopSpeakingClip();
|
||||||
// Finalize with interrupted marker — text already reflects only what was heard
|
// Finalize with interrupted marker — text already reflects only what was heard
|
||||||
finalizeAssistantMessage(true);
|
finalizeAssistantMessage(true);
|
||||||
break;
|
break;
|
||||||
@@ -80,6 +101,141 @@ function handleJSON(msg) {
|
|||||||
pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text });
|
pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text });
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case "video_mode":
|
||||||
|
// Sent once on WS open. Toggles the video element + speaking-clip path.
|
||||||
|
applyVideoModeState(msg);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "speaking_clip":
|
||||||
|
// Envelope preceding an MP4 binary frame with the full turn.
|
||||||
|
pendingSpeakingClipMeta = {
|
||||||
|
chunk_id: msg.chunk_id,
|
||||||
|
duration_ms: msg.duration_ms,
|
||||||
|
text: msg.text,
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Video mode ------------------------------------------------------------
|
||||||
|
|
||||||
|
function applyVideoModeState(msg) {
|
||||||
|
videoModeEnabled = !!msg.enabled && !!msg.ready;
|
||||||
|
videoModeName = msg.mode || "off";
|
||||||
|
idleClipUrl = msg.idle_clip_url || null;
|
||||||
|
refreshStage();
|
||||||
|
}
|
||||||
|
|
||||||
|
function refreshStage() {
|
||||||
|
if (videoModeEnabled && idleClipUrl) {
|
||||||
|
stageEl.classList.add("active");
|
||||||
|
if (avatarVideo.src !== location.origin + idleClipUrl) {
|
||||||
|
avatarVideo.src = idleClipUrl;
|
||||||
|
avatarVideo.loop = true;
|
||||||
|
avatarVideo.muted = true;
|
||||||
|
avatarVideo.play().catch(() => {});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
stageEl.classList.remove("active");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function playSpeakingClip(arrayBuffer, meta) {
|
||||||
|
// Replace the idle loop with the speaking clip.
|
||||||
|
stopSpeakingClip();
|
||||||
|
const blob = new Blob([arrayBuffer], { type: "video/mp4" });
|
||||||
|
currentSpeakingClipBlobUrl = URL.createObjectURL(blob);
|
||||||
|
|
||||||
|
avatarVideo.loop = false;
|
||||||
|
avatarVideo.muted = false;
|
||||||
|
avatarVideo.src = currentSpeakingClipBlobUrl;
|
||||||
|
|
||||||
|
// Show the full reply text now — the MP4 plays it in one shot so there's
|
||||||
|
// no per-chunk sync to do.
|
||||||
|
if (meta && meta.text) {
|
||||||
|
appendAssistantText(meta.text);
|
||||||
|
}
|
||||||
|
isPlaying = true;
|
||||||
|
|
||||||
|
avatarVideo.onended = () => {
|
||||||
|
isPlaying = false;
|
||||||
|
finalizeAssistantMessage(false);
|
||||||
|
// Return to idle loop.
|
||||||
|
if (idleClipUrl) {
|
||||||
|
avatarVideo.loop = true;
|
||||||
|
avatarVideo.muted = true;
|
||||||
|
avatarVideo.src = idleClipUrl;
|
||||||
|
avatarVideo.play().catch(() => {});
|
||||||
|
}
|
||||||
|
if (currentSpeakingClipBlobUrl) {
|
||||||
|
URL.revokeObjectURL(currentSpeakingClipBlobUrl);
|
||||||
|
currentSpeakingClipBlobUrl = null;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
avatarVideo.play().catch((e) => {
|
||||||
|
console.error("speaking clip play failed:", e);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function stopSpeakingClip() {
|
||||||
|
if (!currentSpeakingClipBlobUrl) return;
|
||||||
|
try {
|
||||||
|
avatarVideo.pause();
|
||||||
|
} catch (_) {}
|
||||||
|
URL.revokeObjectURL(currentSpeakingClipBlobUrl);
|
||||||
|
currentSpeakingClipBlobUrl = null;
|
||||||
|
if (idleClipUrl) {
|
||||||
|
avatarVideo.loop = true;
|
||||||
|
avatarVideo.muted = true;
|
||||||
|
avatarVideo.src = idleClipUrl;
|
||||||
|
avatarVideo.play().catch(() => {});
|
||||||
|
}
|
||||||
|
isPlaying = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function uploadAvatar() {
|
||||||
|
const fileInput = document.getElementById("avatar-file");
|
||||||
|
const status = document.getElementById("avatar-status");
|
||||||
|
if (!fileInput.files || !fileInput.files[0]) {
|
||||||
|
status.textContent = "Pick an image first.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
status.textContent = "Uploading and rendering idle clip (this takes a while)...";
|
||||||
|
const fd = new FormData();
|
||||||
|
fd.append("image", fileInput.files[0]);
|
||||||
|
try {
|
||||||
|
const resp = await fetch("/api/set-avatar", { method: "POST", body: fd });
|
||||||
|
if (!resp.ok) throw new Error(await resp.text());
|
||||||
|
const data = await resp.json();
|
||||||
|
idleClipUrl = data.idle_clip_url + "?t=" + Date.now(); // cache-bust
|
||||||
|
videoModeEnabled = true;
|
||||||
|
videoModeName = data.mode || videoModeName;
|
||||||
|
refreshStage();
|
||||||
|
status.textContent = "Avatar ready (" + data.mode + ")";
|
||||||
|
} catch (err) {
|
||||||
|
console.error(err);
|
||||||
|
status.textContent = "Failed: " + err.message;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function applyVideoMode() {
|
||||||
|
const sel = document.getElementById("video-mode-select");
|
||||||
|
const status = document.getElementById("avatar-status");
|
||||||
|
const fd = new FormData();
|
||||||
|
fd.append("mode", sel.value);
|
||||||
|
try {
|
||||||
|
const resp = await fetch("/api/set-video-mode", { method: "POST", body: fd });
|
||||||
|
if (!resp.ok) throw new Error(await resp.text());
|
||||||
|
const data = await resp.json();
|
||||||
|
videoModeName = data.mode;
|
||||||
|
if (data.mode === "off") {
|
||||||
|
videoModeEnabled = false;
|
||||||
|
stageEl.classList.remove("active");
|
||||||
|
}
|
||||||
|
status.textContent = "Mode: " + data.mode + (data.note ? " — " + data.note : "");
|
||||||
|
} catch (err) {
|
||||||
|
status.textContent = "Failed: " + err.message;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,6 +431,7 @@ async function startMic() {
|
|||||||
if (bargeInCount >= BARGE_IN_FRAMES) {
|
if (bargeInCount >= BARGE_IN_FRAMES) {
|
||||||
// User is speaking over the assistant - interrupt
|
// User is speaking over the assistant - interrupt
|
||||||
stopPlayback();
|
stopPlayback();
|
||||||
|
stopSpeakingClip();
|
||||||
const msg = { type: "interrupt" };
|
const msg = { type: "interrupt" };
|
||||||
if (lastDisplayedChunkId >= 0) {
|
if (lastDisplayedChunkId >= 0) {
|
||||||
msg.last_chunk_id = lastDisplayedChunkId;
|
msg.last_chunk_id = lastDisplayedChunkId;
|
||||||
@@ -353,3 +510,5 @@ async function applyVoice() {
|
|||||||
// Expose to HTML onclick
|
// Expose to HTML onclick
|
||||||
window.toggleMic = toggleMic;
|
window.toggleMic = toggleMic;
|
||||||
window.applyVoice = applyVoice;
|
window.applyVoice = applyVoice;
|
||||||
|
window.uploadAvatar = uploadAvatar;
|
||||||
|
window.applyVideoMode = applyVideoMode;
|
||||||
|
|||||||
@@ -12,6 +12,17 @@
|
|||||||
<span id="status-badge">Disconnected</span>
|
<span id="status-badge">Disconnected</span>
|
||||||
</header>
|
</header>
|
||||||
|
|
||||||
|
<div id="stage">
|
||||||
|
<video
|
||||||
|
id="avatar-video"
|
||||||
|
autoplay
|
||||||
|
muted
|
||||||
|
loop
|
||||||
|
playsinline
|
||||||
|
preload="auto"
|
||||||
|
></video>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div id="chat-area"></div>
|
<div id="chat-area"></div>
|
||||||
|
|
||||||
<details id="voice-panel">
|
<details id="voice-panel">
|
||||||
@@ -40,6 +51,27 @@
|
|||||||
</div>
|
</div>
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details id="avatar-panel">
|
||||||
|
<summary>Avatar / Video</summary>
|
||||||
|
<div class="panel-content">
|
||||||
|
<label>
|
||||||
|
Avatar image
|
||||||
|
<input type="file" id="avatar-file" accept="image/*" />
|
||||||
|
</label>
|
||||||
|
<button id="upload-avatar-btn" onclick="uploadAvatar()">Upload</button>
|
||||||
|
<label>
|
||||||
|
Mode
|
||||||
|
<select id="video-mode-select">
|
||||||
|
<option value="off">Off</option>
|
||||||
|
<option value="library">Library (pre-baked)</option>
|
||||||
|
<option value="reflective" selected>Reflective (per-turn)</option>
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
<button id="apply-mode-btn" onclick="applyVideoMode()">Apply mode</button>
|
||||||
|
<span id="avatar-status"></span>
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
<div id="controls">
|
<div id="controls">
|
||||||
<button id="mic-btn" onclick="toggleMic()">🎤</button>
|
<button id="mic-btn" onclick="toggleMic()">🎤</button>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
+39
-4
@@ -52,6 +52,28 @@ header h1 {
|
|||||||
color: #a78bfa;
|
color: #a78bfa;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#stage {
|
||||||
|
display: none; /* toggled on when video mode is enabled */
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
padding: 16px 24px 0;
|
||||||
|
background: #0a0a0a;
|
||||||
|
}
|
||||||
|
|
||||||
|
#stage.active {
|
||||||
|
display: flex;
|
||||||
|
}
|
||||||
|
|
||||||
|
#avatar-video {
|
||||||
|
width: 100%;
|
||||||
|
max-width: 480px;
|
||||||
|
aspect-ratio: 16 / 9;
|
||||||
|
background: #000;
|
||||||
|
border-radius: 12px;
|
||||||
|
object-fit: cover;
|
||||||
|
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4);
|
||||||
|
}
|
||||||
|
|
||||||
#chat-area {
|
#chat-area {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
@@ -130,21 +152,34 @@ header h1 {
|
|||||||
50% { box-shadow: 0 0 0 12px rgba(239, 68, 68, 0); }
|
50% { box-shadow: 0 0 0 12px rgba(239, 68, 68, 0); }
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Voice clone panel */
|
/* Voice + avatar panels */
|
||||||
#voice-panel {
|
#voice-panel,
|
||||||
|
#avatar-panel {
|
||||||
padding: 12px 24px;
|
padding: 12px 24px;
|
||||||
border-top: 1px solid #222;
|
border-top: 1px solid #222;
|
||||||
background: #0a0a0a;
|
background: #0a0a0a;
|
||||||
}
|
}
|
||||||
|
|
||||||
#voice-panel summary {
|
#voice-panel select,
|
||||||
|
#avatar-panel select {
|
||||||
|
background: #1a1a1a;
|
||||||
|
border: 1px solid #333;
|
||||||
|
border-radius: 6px;
|
||||||
|
padding: 6px 10px;
|
||||||
|
color: #e0e0e0;
|
||||||
|
font-size: 13px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#voice-panel summary,
|
||||||
|
#avatar-panel summary {
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
font-size: 13px;
|
font-size: 13px;
|
||||||
color: #888;
|
color: #888;
|
||||||
user-select: none;
|
user-select: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
#voice-panel .panel-content {
|
#voice-panel .panel-content,
|
||||||
|
#avatar-panel .panel-content {
|
||||||
margin-top: 12px;
|
margin-top: 12px;
|
||||||
display: flex;
|
display: flex;
|
||||||
gap: 12px;
|
gap: 12px;
|
||||||
|
|||||||
@@ -0,0 +1,67 @@
|
|||||||
|
# Voice-chat tests
|
||||||
|
|
||||||
|
Two tiers.
|
||||||
|
|
||||||
|
## Unit tests — fast, GPU-free
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m pytest tests/unit -v
|
||||||
|
```
|
||||||
|
|
||||||
|
These exercise pure logic: config parsing, prompt derivation, LoRA spec
|
||||||
|
parsing, frame-length fitting, library round-robin selection, the
|
||||||
|
pipeline's video branch, and ffmpeg mux argument shaping. They do not
|
||||||
|
touch CUDA, Wan2.2, MuseTalk, or a real ffmpeg binary. Safe to run on
|
||||||
|
Windows, outside Docker, without any models installed.
|
||||||
|
|
||||||
|
Current unit files:
|
||||||
|
|
||||||
|
- `test_video_config.py` — `VideoConfig.from_dict` round-trip, LoRA target validation
|
||||||
|
- `test_video_engine_logic.py` — prompt derivation, library cursor, frame fitting
|
||||||
|
- `test_pipeline_video_branch.py` — pipeline takes the video path iff engine is ready
|
||||||
|
- `test_musetalk_fit_frames.py` — frame-length adjustment to match audio duration
|
||||||
|
- `test_muxer_ffmpeg.py` — ffmpeg command construction
|
||||||
|
|
||||||
|
## Component tests — slow, GPU-required, run inside Docker
|
||||||
|
|
||||||
|
Each script in `tests/component/` exercises one subsystem end-to-end
|
||||||
|
against the real models. The numbered prefix reflects the implementation
|
||||||
|
phase each script gates, and also serves as a reasonable run order when
|
||||||
|
debugging a fresh environment:
|
||||||
|
|
||||||
|
| Script | Phase | Tests |
|
||||||
|
|---|---|---|
|
||||||
|
| `test_01_video_skeleton.py` | 1 | VideoEngine loads, config gate respected |
|
||||||
|
| `test_02_wan22_loras.py` | 2 | Wan2.2 pipeline loads, LoRA stack applies |
|
||||||
|
| `test_03_idle_clip.py` | 3 | `set_avatar` → idle MP4, written to disk for eyeballing |
|
||||||
|
| `test_04_library_prebake.py` | 4 | library mode pre-bakes N base clips |
|
||||||
|
| `test_05_musetalk_lipsync.py` | 5 | MuseTalk lip-sync on library frames + ffmpeg mux |
|
||||||
|
| `test_06_reflective.py` | 6 | reflective mode: fresh Wan2.2 per reply |
|
||||||
|
| `test_07_endpoints.py` | 7 | HTTP endpoints return sane responses |
|
||||||
|
| `test_08_lora_reload.py` | 8 | `/api/reload-loras` swaps LoRAs live |
|
||||||
|
| `test_09_gguf_generate.py` | 9 | GGUF-quantised DIT end-to-end I2V generation |
|
||||||
|
| `test_10_t5_encode.py` | 10 | T5 encoder (optionally fp8-quantised) on CUDA |
|
||||||
|
| `test_11_image_encode.py` | 11 | Avatar image → VAE latent path |
|
||||||
|
| `test_12_dit_single_step.py` | 12 | Single DIT step on the loaded expert(s) |
|
||||||
|
| `test_13_vae_decode.py` | 13 | VAE decode back to RGB frames |
|
||||||
|
|
||||||
|
Tests 09-13 are focused on the GGUF + Blackwell (SM120) path and are how
|
||||||
|
new quant schemes / attention backends get validated before wiring them
|
||||||
|
into the full pipeline.
|
||||||
|
|
||||||
|
Run one:
|
||||||
|
|
||||||
|
```
|
||||||
|
# Inside the container:
|
||||||
|
docker compose exec voice-chat python -m tests.component.test_03_idle_clip
|
||||||
|
```
|
||||||
|
|
||||||
|
Run all (slow, ~20+ minutes on a 5090):
|
||||||
|
|
||||||
|
```
|
||||||
|
docker compose exec voice-chat python -m tests.component.run_all
|
||||||
|
```
|
||||||
|
|
||||||
|
Each component script writes its artifacts (MP4s, PNG frame dumps, logs)
|
||||||
|
to `tests/component/_out/` so you can visually inspect results. That
|
||||||
|
directory is gitignored.
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
"""Shared utilities for component tests.
|
||||||
|
|
||||||
|
Component tests run inside the Docker image against real GPU models. They
|
||||||
|
write their output artefacts (MP4s, PNGs, logs) to ``_out/`` so you can
|
||||||
|
visually inspect results.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
OUT_DIR = os.path.join(os.path.dirname(__file__), "_out")
|
||||||
|
os.makedirs(OUT_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# A tiny 256x256 portrait PNG lives next to the component tests so tests
|
||||||
|
# don't need a user-supplied file. If it's missing we synthesise one on
|
||||||
|
# the fly.
|
||||||
|
SAMPLE_AVATAR = os.path.join(os.path.dirname(__file__), "sample_avatar.png")
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str) -> logging.Logger:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(name)s %(levelname)s %(message)s",
|
||||||
|
stream=sys.stdout,
|
||||||
|
)
|
||||||
|
return logging.getLogger(name)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_sample_avatar() -> str:
|
||||||
|
"""Guarantee a usable avatar image exists. Returns its path."""
|
||||||
|
if os.path.isfile(SAMPLE_AVATAR):
|
||||||
|
return SAMPLE_AVATAR
|
||||||
|
# Synthesise a simple gradient PNG as a last resort (won't look like a
|
||||||
|
# person but is valid input for Wan2.2 so the pipeline doesn't fail).
|
||||||
|
try:
|
||||||
|
from PIL import Image # type: ignore[import-not-found]
|
||||||
|
except ImportError:
|
||||||
|
import imageio.v3 as iio # type: ignore[import-not-found]
|
||||||
|
arr = np.zeros((256, 256, 3), dtype=np.uint8)
|
||||||
|
for y in range(256):
|
||||||
|
arr[y, :, 0] = y
|
||||||
|
arr[y, :, 1] = 255 - y
|
||||||
|
arr[y, :, 2] = 128
|
||||||
|
iio.imwrite(SAMPLE_AVATAR, arr)
|
||||||
|
return SAMPLE_AVATAR
|
||||||
|
|
||||||
|
arr = np.zeros((256, 256, 3), dtype=np.uint8)
|
||||||
|
for y in range(256):
|
||||||
|
arr[y, :, 0] = y
|
||||||
|
arr[y, :, 1] = 255 - y
|
||||||
|
arr[y, :, 2] = 128
|
||||||
|
Image.fromarray(arr).save(SAMPLE_AVATAR)
|
||||||
|
return SAMPLE_AVATAR
|
||||||
|
|
||||||
|
|
||||||
|
def write_bytes(name: str, data: bytes) -> str:
|
||||||
|
"""Write an artefact to _out/<name> and return the full path."""
|
||||||
|
path = os.path.join(OUT_DIR, name)
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(data)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def synth_tone(seconds: float, sample_rate: int = 24000, freq: float = 220.0) -> np.ndarray:
|
||||||
|
"""Return a float32 sine tone usable as stand-in TTS audio."""
|
||||||
|
t = np.arange(int(seconds * sample_rate), dtype=np.float32) / sample_rate
|
||||||
|
return (0.2 * np.sin(2 * np.pi * freq * t)).astype(np.float32)
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
"""Run every component test in order. Stops at first failure.
|
||||||
|
|
||||||
|
docker compose exec voice-chat python -m tests.component.run_all
|
||||||
|
"""
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
SCRIPTS = [
|
||||||
|
"tests.component.test_01_video_skeleton",
|
||||||
|
"tests.component.test_02_wan22_loras",
|
||||||
|
"tests.component.test_03_idle_clip",
|
||||||
|
"tests.component.test_04_library_prebake",
|
||||||
|
"tests.component.test_05_musetalk_lipsync",
|
||||||
|
"tests.component.test_06_reflective",
|
||||||
|
"tests.component.test_07_endpoints",
|
||||||
|
"tests.component.test_08_lora_reload",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
failed: list[str] = []
|
||||||
|
for name in SCRIPTS:
|
||||||
|
print(f"\n{'=' * 70}\nRUNNING: {name}\n{'=' * 70}")
|
||||||
|
try:
|
||||||
|
mod = importlib.import_module(name)
|
||||||
|
mod.run()
|
||||||
|
except SystemExit as e:
|
||||||
|
if e.code:
|
||||||
|
print(f"FAILED: {name} (exit {e.code})")
|
||||||
|
failed.append(name)
|
||||||
|
break # hard-stop on failure
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
failed.append(name)
|
||||||
|
break
|
||||||
|
if failed:
|
||||||
|
print(f"\n{len(failed)} failed: {failed}")
|
||||||
|
return 1
|
||||||
|
print("\nALL COMPONENT TESTS PASSED")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 62 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
@@ -0,0 +1,69 @@
|
|||||||
|
"""Phase 1 component test: VideoEngine skeleton + config gate.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- ``ModelManager`` can be imported and constructed.
|
||||||
|
- When ``config.video.enabled=false``, ``_load_video`` skips and leaves
|
||||||
|
``video_engine=None`` (existing voice path unaffected).
|
||||||
|
- When ``config.video.enabled=true``, a ``VideoEngine`` instance is created
|
||||||
|
and ``is_ready()`` returns False (no models loaded yet).
|
||||||
|
|
||||||
|
Does NOT load Wan2.2 or MuseTalk — this test is safe to run on any machine
|
||||||
|
with the python deps installed (no GPU needed).
|
||||||
|
|
||||||
|
Run inside Docker:
|
||||||
|
docker compose exec voice-chat python -m tests.component.test_01_video_skeleton
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from server.models import ModelManager
|
||||||
|
from server.video import VideoConfig, VideoEngine
|
||||||
|
|
||||||
|
from tests.component._common import get_logger
|
||||||
|
|
||||||
|
log = get_logger("test_01")
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
# --- disabled path ---
|
||||||
|
log.info("[case 1] config.video.enabled=False → engine skipped")
|
||||||
|
mgr = ModelManager()
|
||||||
|
# Monkey-patch the config module to simulate disabled
|
||||||
|
import server.config as cfgmod
|
||||||
|
original = cfgmod.config
|
||||||
|
cfgmod.config = {"video": {"enabled": False}, **{k: v for k, v in original.items() if k != "video"}}
|
||||||
|
try:
|
||||||
|
mgr._load_video()
|
||||||
|
assert mgr.video_engine is None, "video_engine should be None when disabled"
|
||||||
|
log.info(" PASS: video_engine is None")
|
||||||
|
finally:
|
||||||
|
cfgmod.config = original
|
||||||
|
|
||||||
|
# --- enabled path (no models loaded) ---
|
||||||
|
log.info("[case 2] config.video.enabled=True → engine created, not ready")
|
||||||
|
mgr2 = ModelManager()
|
||||||
|
cfgmod.config = {
|
||||||
|
**original,
|
||||||
|
"video": {"enabled": True, "mode": "reflective", "loras": []},
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
mgr2._load_video()
|
||||||
|
assert mgr2.video_engine is not None, "video_engine should be created"
|
||||||
|
assert isinstance(mgr2.video_engine, VideoEngine)
|
||||||
|
assert mgr2.video_engine.is_ready() is False
|
||||||
|
log.info(" PASS: engine=%s, ready=%s",
|
||||||
|
type(mgr2.video_engine).__name__, mgr2.video_engine.is_ready())
|
||||||
|
finally:
|
||||||
|
cfgmod.config = original
|
||||||
|
|
||||||
|
log.info("ALL PASSED")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
run()
|
||||||
|
sys.exit(0)
|
||||||
|
except AssertionError as e:
|
||||||
|
log.error("FAILED: %s", e)
|
||||||
|
sys.exit(1)
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
"""Phase 2 component test: dense Wan2.2-TI2V-5B-Turbo pipeline + LoRA stacking.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- ``Wan22Pipeline`` loads successfully (exercises the real LightX2V
|
||||||
|
set_config -> init_runner flow).
|
||||||
|
- ``load_loras`` / ``unload_loras`` survive with any user LoRAs at
|
||||||
|
``/cache/loras/*.safetensors`` (target='both', dense single DIT).
|
||||||
|
|
||||||
|
Supports any GGUF quant published in hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
|
||||||
|
Set ``DIT_QUANT`` to switch (default: ``gguf-Q8_0``).
|
||||||
|
|
||||||
|
DIT_QUANT=gguf-Q4_K_M docker compose exec voice-chat \
|
||||||
|
python -m tests.component.test_02_wan22_loras
|
||||||
|
|
||||||
|
Requires GPU and a first-run download of the base repo + GGUF DIT.
|
||||||
|
If LightX2V isn't installed the test is skipped.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec voice-chat python -m tests.component.test_02_wan22_loras
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tests.component._common import get_logger
|
||||||
|
|
||||||
|
log = get_logger("test_02")
|
||||||
|
|
||||||
|
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
|
||||||
|
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
from server.video_models.wan22 import Wan22Pipeline
|
||||||
|
except ImportError as e:
|
||||||
|
log.error("Wan22Pipeline import failed: %s", e)
|
||||||
|
log.warning("SKIP: phase 2 deps not installed")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
from server.video import LoRASpec
|
||||||
|
|
||||||
|
log.info("[case 1] Instantiate Wan22Pipeline "
|
||||||
|
"(quant=%s, dit_repo=%s)...", DIT_QUANT, DIT_REPO)
|
||||||
|
try:
|
||||||
|
pipe = Wan22Pipeline(
|
||||||
|
base_repo="Wan-AI/Wan2.2-TI2V-5B",
|
||||||
|
dit_repo=DIT_REPO,
|
||||||
|
config_json=CONFIG_JSON,
|
||||||
|
model_cls="wan2.2",
|
||||||
|
resolution=480,
|
||||||
|
fps=16,
|
||||||
|
dit_quant_scheme=DIT_QUANT,
|
||||||
|
t5_quantized=True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error("FAIL: Wan22Pipeline construction raised: %s", e)
|
||||||
|
log.error("Check: LightX2V install, HF cache at /cache/huggingface, "
|
||||||
|
"VRAM headroom, and that %s exists inside the container.",
|
||||||
|
CONFIG_JSON)
|
||||||
|
sys.exit(2)
|
||||||
|
log.info(" PASS: pipeline constructed")
|
||||||
|
|
||||||
|
# --- LoRAs ---
|
||||||
|
log.info("[case 2] load_loras with empty list -> no-op")
|
||||||
|
pipe.load_loras([])
|
||||||
|
log.info(" PASS")
|
||||||
|
|
||||||
|
lora_files = sorted(glob.glob("/cache/loras/*.safetensors"))
|
||||||
|
if not lora_files:
|
||||||
|
log.warning("SKIP: no LoRA files found in /cache/loras/")
|
||||||
|
log.info("ALL PASSED (partial — LoRA cases skipped)")
|
||||||
|
return
|
||||||
|
|
||||||
|
lora_path = lora_files[0]
|
||||||
|
log.info("[case 3] load_loras with one 5B-compatible LoRA (%s)", lora_path)
|
||||||
|
specs = [
|
||||||
|
LoRASpec(
|
||||||
|
path=lora_path,
|
||||||
|
weight=1.0,
|
||||||
|
target="both",
|
||||||
|
name=os.path.basename(lora_path),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
pipe.load_loras(specs)
|
||||||
|
except Exception as e:
|
||||||
|
log.error("FAIL: load_loras raised: %s", e)
|
||||||
|
log.error("Check: LoRA checkpoint shape matches dense 5B DIT.")
|
||||||
|
sys.exit(3)
|
||||||
|
log.info(" PASS: LoRAs applied")
|
||||||
|
|
||||||
|
log.info("[case 4] unload_loras")
|
||||||
|
try:
|
||||||
|
pipe.unload_loras()
|
||||||
|
except Exception as e:
|
||||||
|
log.error("FAIL: unload_loras raised: %s", e)
|
||||||
|
sys.exit(4)
|
||||||
|
log.info(" PASS")
|
||||||
|
|
||||||
|
log.info("ALL PASSED")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
"""Phase 3 component test: avatar upload → idle clip generation.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- ``VideoEngine.load_models()`` + ``set_avatar(image)`` produces a non-empty
|
||||||
|
idle MP4 blob.
|
||||||
|
- The blob decodes as a valid MP4 (ftyp header).
|
||||||
|
|
||||||
|
Writes the idle clip to ``tests/component/_out/phase3_idle.mp4`` so you can
|
||||||
|
inspect it visually.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec voice-chat python -m tests.component.test_03_idle_clip
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from server.video import VideoConfig, VideoEngine
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
|
||||||
|
|
||||||
|
log = get_logger("test_03")
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
avatar_path = ensure_sample_avatar()
|
||||||
|
log.info("Using avatar: %s", avatar_path)
|
||||||
|
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"mode": "reflective", # reflective skips the library prebake
|
||||||
|
"resolution": 480,
|
||||||
|
"fps": 16,
|
||||||
|
"library": {"base_clip_count": 0, "base_clip_seconds": 3},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
engine = VideoEngine(cfg)
|
||||||
|
|
||||||
|
log.info("Loading models (Wan2.2 + MuseTalk)...")
|
||||||
|
try:
|
||||||
|
engine.load_models()
|
||||||
|
except Exception as e:
|
||||||
|
log.error("FAIL: load_models raised: %s", e)
|
||||||
|
sys.exit(2)
|
||||||
|
log.info("Models loaded.")
|
||||||
|
|
||||||
|
log.info("Generating idle clip for avatar...")
|
||||||
|
try:
|
||||||
|
engine.set_avatar(avatar_path)
|
||||||
|
except Exception as e:
|
||||||
|
log.error("FAIL: set_avatar raised: %s", e)
|
||||||
|
sys.exit(3)
|
||||||
|
|
||||||
|
idle = engine.get_idle_clip()
|
||||||
|
assert idle is not None and len(idle) > 0, "idle clip is empty"
|
||||||
|
assert idle[4:8] == b"ftyp", "idle clip is not a valid MP4"
|
||||||
|
|
||||||
|
out_path = write_bytes("phase3_idle.mp4", idle)
|
||||||
|
log.info("PASS: idle clip written to %s (%d bytes)", out_path, len(idle))
|
||||||
|
|
||||||
|
assert engine.is_ready() is True
|
||||||
|
log.info(" engine.is_ready() = True (avatar + models present)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
"""Phase 4 component test: library mode pre-bake of speaking-base clips.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- ``set_avatar`` under ``mode=library`` populates ``speaking_base_frames``
|
||||||
|
with ``library_base_clip_count`` entries.
|
||||||
|
- Each cached entry has shape ``[T, H, W, 3]`` uint8.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec voice-chat python -m tests.component.test_04_library_prebake
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from server.video import VideoConfig, VideoEngine
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger
|
||||||
|
|
||||||
|
log = get_logger("test_04")
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
avatar_path = ensure_sample_avatar()
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"mode": "library",
|
||||||
|
"resolution": 480,
|
||||||
|
"fps": 16,
|
||||||
|
"library": {"base_clip_count": 2, "base_clip_seconds": 3},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
engine = VideoEngine(cfg)
|
||||||
|
|
||||||
|
log.info("Loading models...")
|
||||||
|
engine.load_models()
|
||||||
|
|
||||||
|
log.info("Pre-baking 2 library clips...")
|
||||||
|
engine.set_avatar(avatar_path)
|
||||||
|
|
||||||
|
assert len(engine.speaking_base_frames) == 2, \
|
||||||
|
f"expected 2 base clips, got {len(engine.speaking_base_frames)}"
|
||||||
|
for i, frames in enumerate(engine.speaking_base_frames):
|
||||||
|
assert isinstance(frames, np.ndarray)
|
||||||
|
assert frames.ndim == 4 and frames.shape[-1] == 3
|
||||||
|
assert frames.dtype == np.uint8
|
||||||
|
log.info(" clip %d: shape=%s", i, frames.shape)
|
||||||
|
|
||||||
|
assert engine.get_idle_clip() is not None
|
||||||
|
log.info("PASS: library pre-bake complete")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
"""Phase 5 component test: MuseTalk lip-sync + ffmpeg mux.
|
||||||
|
|
||||||
|
Verifies the full library-mode per-turn path:
|
||||||
|
- Pre-bake a library clip.
|
||||||
|
- Generate a stand-in TTS waveform (sine tone).
|
||||||
|
- Call ``VideoEngine.generate_speaking_clip`` and get a valid MP4 back.
|
||||||
|
|
||||||
|
Writes the resulting clip to ``tests/component/_out/phase5_speaking.mp4``.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec voice-chat python -m tests.component.test_05_musetalk_lipsync
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from server.video import VideoConfig, VideoEngine
|
||||||
|
from tests.component._common import (
|
||||||
|
ensure_sample_avatar,
|
||||||
|
get_logger,
|
||||||
|
synth_tone,
|
||||||
|
write_bytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = get_logger("test_05")
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
avatar_path = ensure_sample_avatar()
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"mode": "library",
|
||||||
|
"resolution": 480,
|
||||||
|
"fps": 16,
|
||||||
|
"library": {"base_clip_count": 1, "base_clip_seconds": 4},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
engine = VideoEngine(cfg)
|
||||||
|
engine.load_models()
|
||||||
|
engine.set_avatar(avatar_path)
|
||||||
|
|
||||||
|
audio = synth_tone(seconds=3.0, sample_rate=24000, freq=220.0)
|
||||||
|
log.info("Generating library-mode speaking clip (3s audio)...")
|
||||||
|
mp4 = engine.generate_speaking_clip(
|
||||||
|
audio_f32=audio,
|
||||||
|
sample_rate=24000,
|
||||||
|
reply_text="Hello, this is a lip-sync test.",
|
||||||
|
)
|
||||||
|
assert isinstance(mp4, bytes) and len(mp4) > 0
|
||||||
|
assert mp4[4:8] == b"ftyp"
|
||||||
|
out = write_bytes("phase5_speaking.mp4", mp4)
|
||||||
|
log.info("PASS: speaking clip written to %s (%d bytes)", out, len(mp4))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
"""Phase 6 component test: reflective mode (fresh Wan2.2 clip per turn).
|
||||||
|
|
||||||
|
Verifies that with ``mode=reflective``, ``generate_speaking_clip`` runs
|
||||||
|
the Wan2.2 image-to-video pipeline once per call (so the base frames
|
||||||
|
differ from turn to turn) and the prompt is derived from the reply text.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec voice-chat python -m tests.component.test_06_reflective
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from server.video import VideoConfig, VideoEngine
|
||||||
|
from tests.component._common import (
|
||||||
|
ensure_sample_avatar,
|
||||||
|
get_logger,
|
||||||
|
synth_tone,
|
||||||
|
write_bytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = get_logger("test_06")
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
avatar_path = ensure_sample_avatar()
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"mode": "reflective",
|
||||||
|
"resolution": 480,
|
||||||
|
"fps": 16,
|
||||||
|
"reflective": {"clip_seconds": 3},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
engine = VideoEngine(cfg)
|
||||||
|
engine.load_models()
|
||||||
|
engine.set_avatar(avatar_path)
|
||||||
|
|
||||||
|
# Verify prompt derivation includes the reply hint
|
||||||
|
prompt = engine._derive_prompt(
|
||||||
|
"The assistant walks along a sunny beach watching seagulls."
|
||||||
|
)
|
||||||
|
log.info("derived prompt: %s", prompt)
|
||||||
|
assert "beach" in prompt, "reply_hint did not survive template interpolation"
|
||||||
|
|
||||||
|
audio = synth_tone(seconds=3.0)
|
||||||
|
log.info("Generating reflective speaking clip #1...")
|
||||||
|
mp4_a = engine.generate_speaking_clip(
|
||||||
|
audio, 24000, "The assistant walks along a sunny beach watching seagulls."
|
||||||
|
)
|
||||||
|
write_bytes("phase6_reflective_beach.mp4", mp4_a)
|
||||||
|
|
||||||
|
log.info("Generating reflective speaking clip #2...")
|
||||||
|
mp4_b = engine.generate_speaking_clip(
|
||||||
|
audio, 24000, "Now the character stands in a snow-covered forest at dusk."
|
||||||
|
)
|
||||||
|
write_bytes("phase6_reflective_snow.mp4", mp4_b)
|
||||||
|
|
||||||
|
# Not a strict assertion (same prompt could yield identical bytes if seeded),
|
||||||
|
# but with different prompts and random seeds the blobs should differ.
|
||||||
|
if mp4_a != mp4_b:
|
||||||
|
log.info("PASS: reflective clips differ as expected")
|
||||||
|
else:
|
||||||
|
log.warning("clips are byte-identical — check that seeds are random")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
"""Phase 7 component test: HTTP endpoints (/api/set-avatar, /api/idle-clip,
|
||||||
|
/api/set-video-mode, /api/reload-loras, WebSocket handshake video_mode msg).
|
||||||
|
|
||||||
|
Uses FastAPI's ``TestClient`` so we don't need a running uvicorn server.
|
||||||
|
Stubs the model manager to avoid loading Wan2.2 — we only care that the
|
||||||
|
HTTP surface is plumbed correctly.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec voice-chat python -m tests.component.test_07_endpoints
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tests.component._common import get_logger
|
||||||
|
|
||||||
|
log = get_logger("test_07")
|
||||||
|
|
||||||
|
|
||||||
|
def _stub_video_engine():
|
||||||
|
class StubCfg:
|
||||||
|
mode = "reflective"
|
||||||
|
class StubEngine:
|
||||||
|
cfg = StubCfg()
|
||||||
|
avatar_path = None
|
||||||
|
def __init__(self): self.idle = b"FAKE_MP4"
|
||||||
|
def is_ready(self): return bool(self.avatar_path)
|
||||||
|
def get_idle_clip(self): return self.idle
|
||||||
|
def set_avatar(self, path): self.avatar_path = path
|
||||||
|
def load_loras(self, specs): self._last_loras = specs
|
||||||
|
return StubEngine()
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
import server.main as main_mod
|
||||||
|
|
||||||
|
# Inject a stub engine so we never touch Wan2.2.
|
||||||
|
main_mod.model_mgr.video_engine = _stub_video_engine()
|
||||||
|
|
||||||
|
# Bypass the heavy lifespan (model loading) so TestClient starts fast.
|
||||||
|
main_mod.app.router.lifespan_context = None # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
client = TestClient(main_mod.app)
|
||||||
|
|
||||||
|
# --- set-avatar ---
|
||||||
|
log.info("[case 1] POST /api/set-avatar")
|
||||||
|
fake_png = b"\x89PNG\r\n\x1a\n" + b"\x00" * 64 # minimal PNG header
|
||||||
|
resp = client.post(
|
||||||
|
"/api/set-avatar",
|
||||||
|
files={"image": ("avatar.png", io.BytesIO(fake_png), "image/png")},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200, f"got {resp.status_code}: {resp.text}"
|
||||||
|
data = resp.json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
assert data["idle_clip_url"] == "/api/idle-clip"
|
||||||
|
log.info(" PASS: %s", data)
|
||||||
|
|
||||||
|
# --- idle-clip ---
|
||||||
|
log.info("[case 2] GET /api/idle-clip")
|
||||||
|
resp = client.get("/api/idle-clip")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == b"FAKE_MP4"
|
||||||
|
assert resp.headers["content-type"] == "video/mp4"
|
||||||
|
log.info(" PASS")
|
||||||
|
|
||||||
|
# --- set-video-mode ---
|
||||||
|
log.info("[case 3] POST /api/set-video-mode")
|
||||||
|
for mode in ("off", "library", "reflective"):
|
||||||
|
resp = client.post("/api/set-video-mode", data={"mode": mode})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["mode"] == mode
|
||||||
|
resp = client.post("/api/set-video-mode", data={"mode": "bogus"})
|
||||||
|
assert resp.status_code == 400
|
||||||
|
log.info(" PASS")
|
||||||
|
|
||||||
|
# --- reload-loras ---
|
||||||
|
log.info("[case 4] POST /api/reload-loras")
|
||||||
|
body = {
|
||||||
|
"loras": [
|
||||||
|
{"path": "/cache/loras/a.safetensors", "weight": 0.8,
|
||||||
|
"target": "both", "name": "test-a"},
|
||||||
|
{"path": "/cache/loras/b.safetensors", "weight": 0.4,
|
||||||
|
"target": "both"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
resp = client.post("/api/reload-loras", json=body)
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
data = resp.json()
|
||||||
|
assert data["lora_count"] == 2
|
||||||
|
log.info(" PASS: %s", data)
|
||||||
|
|
||||||
|
# --- WebSocket video_mode handshake ---
|
||||||
|
log.info("[case 5] WebSocket /ws/chat → video_mode announcement")
|
||||||
|
with client.websocket_connect("/ws/chat") as websocket:
|
||||||
|
msgs = []
|
||||||
|
for _ in range(5):
|
||||||
|
try:
|
||||||
|
msg = websocket.receive_json()
|
||||||
|
msgs.append(msg)
|
||||||
|
if msg.get("type") == "video_mode":
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
assert any(m.get("type") == "video_mode" for m in msgs), msgs
|
||||||
|
log.info(" PASS")
|
||||||
|
|
||||||
|
log.info("ALL PASSED")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
"""Phase 8 component test: /api/reload-loras hot-swap.
|
||||||
|
|
||||||
|
Verifies that ``VideoEngine.load_loras`` can be called again after startup
|
||||||
|
and the idle clip is regenerated to reflect the new style.
|
||||||
|
|
||||||
|
This test is the 'real model' version of test_07's reload endpoint stub.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec voice-chat python -m tests.component.test_08_lora_reload
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from server.video import LoRASpec, VideoConfig, VideoEngine
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
|
||||||
|
|
||||||
|
log = get_logger("test_08")
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
avatar_path = ensure_sample_avatar()
|
||||||
|
cfg = VideoConfig.from_dict({"enabled": True, "mode": "reflective"})
|
||||||
|
engine = VideoEngine(cfg)
|
||||||
|
engine.load_models()
|
||||||
|
|
||||||
|
# Initial state: no LoRAs
|
||||||
|
engine.set_avatar(avatar_path)
|
||||||
|
idle_a = engine.get_idle_clip()
|
||||||
|
assert idle_a is not None
|
||||||
|
hash_a = hashlib.sha256(idle_a).hexdigest()
|
||||||
|
write_bytes("phase8_idle_noloras.mp4", idle_a)
|
||||||
|
log.info("idle (no LoRAs) sha256=%s", hash_a[:16])
|
||||||
|
|
||||||
|
# Hot-reload flow: unload (no-op), reload empty list, verify clip still generates.
|
||||||
|
# There are no published 5B-Turbo-compatible LoRAs yet; when one exists,
|
||||||
|
# construct a LoRASpec(path=..., target="both", weight=1.0) and compare hashes.
|
||||||
|
engine.load_loras([])
|
||||||
|
engine.set_avatar(avatar_path)
|
||||||
|
idle_b = engine.get_idle_clip()
|
||||||
|
assert idle_b is not None
|
||||||
|
hash_b = hashlib.sha256(idle_b).hexdigest()
|
||||||
|
write_bytes("phase8_idle_reloaded.mp4", idle_b)
|
||||||
|
log.info("idle (post-reload) sha256=%s", hash_b[:16])
|
||||||
|
|
||||||
|
log.info("PASS: hot-reload round-trip completed "
|
||||||
|
"(hash match=%s — expected without a real LoRA applied).",
|
||||||
|
hash_a == hash_b)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
"""Quick smoke test: generate a video clip with the dense 5B Turbo GGUF pipeline.
|
||||||
|
|
||||||
|
Calls Wan22Pipeline.generate_i2v directly (no MuseTalk, no VideoEngine)
|
||||||
|
and writes the result to tests/component/_out/phase9_gguf.mp4.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec -e DIT_QUANT=gguf-Q8_0 voice-chat \
|
||||||
|
python -m tests.component.test_09_gguf_generate
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
|
||||||
|
|
||||||
|
log = get_logger("test_09")
|
||||||
|
|
||||||
|
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
|
||||||
|
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
from server.video_models.wan22 import Wan22Pipeline
|
||||||
|
except ImportError as e:
|
||||||
|
log.error("Import failed: %s", e)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
avatar = ensure_sample_avatar()
|
||||||
|
log.info("Avatar: %s", avatar)
|
||||||
|
|
||||||
|
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
|
||||||
|
pipe = Wan22Pipeline(
|
||||||
|
base_repo="Wan-AI/Wan2.2-TI2V-5B",
|
||||||
|
dit_repo=DIT_REPO,
|
||||||
|
config_json=CONFIG_JSON,
|
||||||
|
model_cls="wan2.2",
|
||||||
|
resolution=480,
|
||||||
|
fps=16,
|
||||||
|
dit_quant_scheme=DIT_QUANT,
|
||||||
|
t5_quantized=True,
|
||||||
|
)
|
||||||
|
log.info("Pipeline ready.")
|
||||||
|
|
||||||
|
# Debug: verify DTYPE is set correctly for GGUF
|
||||||
|
from lightx2v.utils.envs import GET_DTYPE
|
||||||
|
log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))
|
||||||
|
|
||||||
|
log.info("Generating 3-second i2v clip...")
|
||||||
|
frames = pipe.generate_i2v(
|
||||||
|
image_path=avatar,
|
||||||
|
prompt="a person looking at the camera, natural lighting, soft focus background",
|
||||||
|
seconds=3,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
log.info("Got frames: shape=%s dtype=%s", frames.shape, frames.dtype)
|
||||||
|
|
||||||
|
# Encode to MP4
|
||||||
|
import imageio.v3 as iio
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
|
||||||
|
tmp = tf.name
|
||||||
|
try:
|
||||||
|
iio.imwrite(tmp, frames, fps=16, codec="libx264")
|
||||||
|
with open(tmp, "rb") as f:
|
||||||
|
mp4_bytes = f.read()
|
||||||
|
finally:
|
||||||
|
os.remove(tmp)
|
||||||
|
|
||||||
|
out = write_bytes("phase9_gguf.mp4", mp4_bytes)
|
||||||
|
log.info("PASS: video written to %s (%d bytes, %d frames)", out, len(mp4_bytes), frames.shape[0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
"""Smoke test: T5 text encoding under GGUF pipeline.
|
||||||
|
|
||||||
|
Builds the Wan22Pipeline (loads all weights including DIT) but only
|
||||||
|
exercises the T5 encoder — no image-to-video generation. Validates that
|
||||||
|
the DTYPE=FP16 ↔ BF16 patching lets T5 encode a prompt successfully.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||||
|
python -m tests.component.test_10_t5_encode
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tests.component._common import get_logger
|
||||||
|
|
||||||
|
log = get_logger("test_10")
|
||||||
|
|
||||||
|
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
|
||||||
|
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
from server.video_models.wan22 import Wan22Pipeline
|
||||||
|
except ImportError as e:
|
||||||
|
log.error("Import failed: %s", e)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT)
|
||||||
|
pipe = Wan22Pipeline(
|
||||||
|
base_repo="Wan-AI/Wan2.2-TI2V-5B",
|
||||||
|
dit_repo=DIT_REPO,
|
||||||
|
config_json=CONFIG_JSON,
|
||||||
|
model_cls="wan2.2",
|
||||||
|
resolution=480,
|
||||||
|
fps=16,
|
||||||
|
dit_quant_scheme=DIT_QUANT,
|
||||||
|
t5_quantized=True,
|
||||||
|
)
|
||||||
|
log.info("Pipeline ready.")
|
||||||
|
|
||||||
|
# Check DTYPE state after init
|
||||||
|
from lightx2v.utils.envs import GET_DTYPE
|
||||||
|
log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))
|
||||||
|
|
||||||
|
# Run only the T5 text encoder
|
||||||
|
runner = pipe._runner
|
||||||
|
prompt = "a person looking at the camera, natural lighting"
|
||||||
|
|
||||||
|
log.info("Running T5 text encoder on prompt: %r", prompt)
|
||||||
|
import copy
|
||||||
|
input_info = copy.deepcopy(pipe._input_info_template)
|
||||||
|
input_info.prompt = prompt
|
||||||
|
|
||||||
|
runner.run_text_encoder(input_info)
|
||||||
|
log.info("T5 encode complete.")
|
||||||
|
|
||||||
|
# Inspect output — check all dataclass fields for tensor results
|
||||||
|
import torch
|
||||||
|
for attr in vars(input_info):
|
||||||
|
val = getattr(input_info, attr)
|
||||||
|
if isinstance(val, torch.Tensor):
|
||||||
|
log.info(" %s: shape=%s dtype=%s device=%s", attr, val.shape, val.dtype, val.device)
|
||||||
|
|
||||||
|
# Verify DTYPE is back to FP16 after T5 runs (if GGUF)
|
||||||
|
if DIT_QUANT.startswith("gguf-"):
|
||||||
|
current = GET_DTYPE()
|
||||||
|
import torch
|
||||||
|
expected = torch.float16
|
||||||
|
if current == expected:
|
||||||
|
log.info("PASS: DTYPE correctly restored to FP16 after T5 encode.")
|
||||||
|
else:
|
||||||
|
log.error("FAIL: DTYPE is %s after T5, expected %s", current, expected)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
log.info("PASS: T5 encoding succeeded under %s pipeline.", DIT_QUANT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
"""Smoke test: image reading + VAE encoder (+ CLIP if enabled) under GGUF pipeline.
|
||||||
|
|
||||||
|
Builds the Wan22Pipeline, loads a sample avatar, reads the image input,
|
||||||
|
runs the CLIP image encoder (if use_image_encoder is true in the config),
|
||||||
|
and runs the VAE encoder. Validates outputs under DTYPE=FP16.
|
||||||
|
|
||||||
|
Note: The GGUF distill config sets use_image_encoder=false, so CLIP is
|
||||||
|
skipped by default. The VAE encoder is always exercised.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||||
|
python -m tests.component.test_11_image_encode
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger
|
||||||
|
|
||||||
|
log = get_logger("test_11")
|
||||||
|
|
||||||
|
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
|
||||||
|
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
from server.video_models.wan22 import Wan22Pipeline
|
||||||
|
except ImportError as e:
|
||||||
|
log.error("Import failed: %s", e)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
avatar = ensure_sample_avatar()
|
||||||
|
log.info("Avatar: %s", avatar)
|
||||||
|
|
||||||
|
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
|
||||||
|
pipe = Wan22Pipeline(
|
||||||
|
base_repo="Wan-AI/Wan2.2-TI2V-5B",
|
||||||
|
dit_repo=DIT_REPO,
|
||||||
|
config_json=CONFIG_JSON,
|
||||||
|
model_cls="wan2.2",
|
||||||
|
resolution=480,
|
||||||
|
fps=16,
|
||||||
|
dit_quant_scheme=DIT_QUANT,
|
||||||
|
t5_quantized=True,
|
||||||
|
)
|
||||||
|
log.info("Pipeline ready.")
|
||||||
|
|
||||||
|
runner = pipe._runner
|
||||||
|
|
||||||
|
# Set up input_info so runner methods can access it
|
||||||
|
from lightx2v.utils.input_info import update_input_info_from_dict
|
||||||
|
update_input_info_from_dict(
|
||||||
|
pipe._input_info_template,
|
||||||
|
{
|
||||||
|
"seed": 42,
|
||||||
|
"prompt": "a person looking at the camera, natural lighting",
|
||||||
|
"negative_prompt": "",
|
||||||
|
"image_path": avatar,
|
||||||
|
"target_video_length": 17,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
runner.input_info = pipe._input_info_template
|
||||||
|
|
||||||
|
# 1. Load image
|
||||||
|
log.info("Reading image input...")
|
||||||
|
img, img_ori = runner.read_image_input(avatar)
|
||||||
|
log.info("img: shape=%s dtype=%s device=%s", img.shape, img.dtype, img.device)
|
||||||
|
|
||||||
|
# 2. CLIP image encoder (only if enabled in config)
|
||||||
|
use_clip = runner.config.get("use_image_encoder", True)
|
||||||
|
if use_clip:
|
||||||
|
log.info("Running CLIP image encoder...")
|
||||||
|
clip_out = runner.run_image_encoder(img)
|
||||||
|
log.info("clip_out: shape=%s dtype=%s device=%s", clip_out.shape, clip_out.dtype, clip_out.device)
|
||||||
|
assert isinstance(clip_out, torch.Tensor), f"Expected tensor, got {type(clip_out)}"
|
||||||
|
log.info("PASS: CLIP image encoder succeeded.")
|
||||||
|
else:
|
||||||
|
log.info("CLIP image encoder disabled (use_image_encoder=false) — skipping.")
|
||||||
|
|
||||||
|
# 3. VAE encoder
|
||||||
|
vae_input = img_ori if runner.vae_encoder_need_img_original else img
|
||||||
|
log.info("Running VAE encoder (using %s)...",
|
||||||
|
"img_ori" if runner.vae_encoder_need_img_original else "img tensor")
|
||||||
|
vae_out, latent_shape = runner.run_vae_encoder(vae_input)
|
||||||
|
log.info("latent_shape: %s", latent_shape)
|
||||||
|
if isinstance(vae_out, torch.Tensor):
|
||||||
|
log.info("vae_out: shape=%s dtype=%s device=%s", vae_out.shape, vae_out.dtype, vae_out.device)
|
||||||
|
elif isinstance(vae_out, dict):
|
||||||
|
for k, v in vae_out.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
log.info("vae_out[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
|
||||||
|
else:
|
||||||
|
log.info("vae_out[%s]: type=%s", k, type(v))
|
||||||
|
else:
|
||||||
|
log.info("vae_out: type=%s", type(vae_out))
|
||||||
|
log.info("PASS: VAE encoder succeeded.")
|
||||||
|
|
||||||
|
log.info("PASS: All image encoding stages completed under %s pipeline.", DIT_QUANT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
"""Smoke test: single DIT denoising step with GGUF weights.
|
||||||
|
|
||||||
|
Builds the pipeline, runs all encoders, initializes the scheduler, then
|
||||||
|
executes exactly one DIT forward pass (step_pre → infer → step_post).
|
||||||
|
This isolates the GGUF fp16 DIT from the rest of the pipeline.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||||
|
python -m tests.component.test_12_dit_single_step
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger
|
||||||
|
|
||||||
|
log = get_logger("test_12")
|
||||||
|
|
||||||
|
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
|
||||||
|
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
from server.video_models.wan22 import Wan22Pipeline
|
||||||
|
except ImportError as e:
|
||||||
|
log.error("Import failed: %s", e)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
avatar = ensure_sample_avatar()
|
||||||
|
log.info("Avatar: %s", avatar)
|
||||||
|
|
||||||
|
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
|
||||||
|
pipe = Wan22Pipeline(
|
||||||
|
base_repo="Wan-AI/Wan2.2-TI2V-5B",
|
||||||
|
dit_repo=DIT_REPO,
|
||||||
|
config_json=CONFIG_JSON,
|
||||||
|
model_cls="wan2.2",
|
||||||
|
resolution=480,
|
||||||
|
fps=16,
|
||||||
|
dit_quant_scheme=DIT_QUANT,
|
||||||
|
t5_quantized=True,
|
||||||
|
)
|
||||||
|
log.info("Pipeline ready.")
|
||||||
|
|
||||||
|
runner = pipe._runner
|
||||||
|
|
||||||
|
# Set up input_info for a short clip
|
||||||
|
from lightx2v.utils.input_info import update_input_info_from_dict
|
||||||
|
update_input_info_from_dict(
|
||||||
|
pipe._input_info_template,
|
||||||
|
{
|
||||||
|
"seed": 42,
|
||||||
|
"prompt": "a person looking at the camera, natural lighting",
|
||||||
|
"negative_prompt": "",
|
||||||
|
"image_path": avatar,
|
||||||
|
"target_video_length": 17, # 1 second at 16fps + 1
|
||||||
|
},
|
||||||
|
)
|
||||||
|
runner.input_info = pipe._input_info_template
|
||||||
|
|
||||||
|
# 1. Run all encoders (T5 + CLIP + VAE)
|
||||||
|
log.info("Running all input encoders (T5 + CLIP + VAE)...")
|
||||||
|
runner.inputs = runner.run_input_encoder()
|
||||||
|
log.info("Encoder outputs ready.")
|
||||||
|
for k, v in runner.inputs.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
log.info(" inputs[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
for k2, v2 in v.items():
|
||||||
|
if isinstance(v2, torch.Tensor):
|
||||||
|
log.info(" inputs[%s][%s]: shape=%s dtype=%s", k, k2, v2.shape, v2.dtype)
|
||||||
|
|
||||||
|
# 2. Initialize run (sets up scheduler, creates noise latents)
|
||||||
|
log.info("Initializing run (scheduler.prepare)...")
|
||||||
|
runner.init_run()
|
||||||
|
latents = runner.model.scheduler.latents
|
||||||
|
log.info("Initial latents: shape=%s dtype=%s", latents.shape, latents.dtype)
|
||||||
|
|
||||||
|
# 3. Single DIT step
|
||||||
|
log.info("Running single DIT step (step_pre → infer → step_post)...")
|
||||||
|
runner.model.scheduler.step_pre(step_index=0)
|
||||||
|
runner.model.infer(runner.inputs)
|
||||||
|
runner.model.scheduler.step_post()
|
||||||
|
|
||||||
|
latents_after = runner.model.scheduler.latents
|
||||||
|
log.info("Latents after step: shape=%s dtype=%s", latents_after.shape, latents_after.dtype)
|
||||||
|
|
||||||
|
# Verify latents changed (denoising did something)
|
||||||
|
assert not torch.equal(latents, latents_after), "Latents unchanged after DIT step"
|
||||||
|
log.info("PASS: DIT single step completed, latents updated.")
|
||||||
|
|
||||||
|
log.info("PASS: DIT forward pass succeeded under %s pipeline.", DIT_QUANT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
"""Smoke test: VAE decoder under GGUF pipeline.
|
||||||
|
|
||||||
|
Builds the pipeline, runs all encoders, initializes the scheduler, executes
|
||||||
|
one DIT denoising step, then decodes the resulting latents back to pixel
|
||||||
|
frames via the VAE decoder. Validates the full encode→denoise→decode path.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||||
|
python -m tests.component.test_13_vae_decode
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
|
||||||
|
|
||||||
|
log = get_logger("test_13")
|
||||||
|
|
||||||
|
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
|
||||||
|
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
from server.video_models.wan22 import Wan22Pipeline
|
||||||
|
except ImportError as e:
|
||||||
|
log.error("Import failed: %s", e)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
avatar = ensure_sample_avatar()
|
||||||
|
log.info("Avatar: %s", avatar)
|
||||||
|
|
||||||
|
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
|
||||||
|
pipe = Wan22Pipeline(
|
||||||
|
base_repo="Wan-AI/Wan2.2-TI2V-5B",
|
||||||
|
dit_repo=DIT_REPO,
|
||||||
|
config_json=CONFIG_JSON,
|
||||||
|
model_cls="wan2.2",
|
||||||
|
resolution=480,
|
||||||
|
fps=16,
|
||||||
|
dit_quant_scheme=DIT_QUANT,
|
||||||
|
t5_quantized=True,
|
||||||
|
)
|
||||||
|
log.info("Pipeline ready.")
|
||||||
|
|
||||||
|
runner = pipe._runner
|
||||||
|
|
||||||
|
# Set up input_info for a short clip
|
||||||
|
from lightx2v.utils.input_info import update_input_info_from_dict
|
||||||
|
update_input_info_from_dict(
|
||||||
|
pipe._input_info_template,
|
||||||
|
{
|
||||||
|
"seed": 42,
|
||||||
|
"prompt": "a person looking at the camera, natural lighting",
|
||||||
|
"negative_prompt": "",
|
||||||
|
"image_path": avatar,
|
||||||
|
"target_video_length": 17, # 1 second at 16fps + 1
|
||||||
|
},
|
||||||
|
)
|
||||||
|
runner.input_info = pipe._input_info_template
|
||||||
|
|
||||||
|
# 1. Run all encoders (T5 + CLIP + VAE)
|
||||||
|
log.info("Running all input encoders (T5 + CLIP + VAE)...")
|
||||||
|
runner.inputs = runner.run_input_encoder()
|
||||||
|
log.info("Encoder outputs ready.")
|
||||||
|
|
||||||
|
# 2. Initialize run (sets up scheduler, creates noise latents)
|
||||||
|
log.info("Initializing run (scheduler.prepare)...")
|
||||||
|
runner.init_run()
|
||||||
|
log.info("Initial latents: shape=%s dtype=%s",
|
||||||
|
runner.model.scheduler.latents.shape,
|
||||||
|
runner.model.scheduler.latents.dtype)
|
||||||
|
|
||||||
|
# 3. Single DIT step (so we have realistic latents to decode)
|
||||||
|
log.info("Running single DIT step...")
|
||||||
|
runner.model.scheduler.step_pre(step_index=0)
|
||||||
|
runner.model.infer(runner.inputs)
|
||||||
|
runner.model.scheduler.step_post()
|
||||||
|
latents = runner.model.scheduler.latents
|
||||||
|
log.info("Latents after step: shape=%s dtype=%s", latents.shape, latents.dtype)
|
||||||
|
|
||||||
|
# 4. VAE decode
|
||||||
|
log.info("Running VAE decoder...")
|
||||||
|
video_out = runner.run_vae_decoder(latents)
|
||||||
|
log.info("VAE decoder output type: %s", type(video_out))
|
||||||
|
if isinstance(video_out, torch.Tensor):
|
||||||
|
log.info("video_out: shape=%s dtype=%s device=%s",
|
||||||
|
video_out.shape, video_out.dtype, video_out.device)
|
||||||
|
elif isinstance(video_out, list):
|
||||||
|
log.info("video_out: list of %d items", len(video_out))
|
||||||
|
if len(video_out) > 0 and isinstance(video_out[0], torch.Tensor):
|
||||||
|
log.info(" first item: shape=%s dtype=%s", video_out[0].shape, video_out[0].dtype)
|
||||||
|
else:
|
||||||
|
log.info("video_out: %s", video_out)
|
||||||
|
|
||||||
|
log.info("PASS: VAE decoder succeeded under %s pipeline.", DIT_QUANT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
"""Unit tests for the frame-length fitting helper in server.video_models.musetalk.
|
||||||
|
|
||||||
|
Pure-python: does not import MuseTalk itself.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from server.video_models.musetalk import _fit_frames_to_length, _ensure_uint8_rgb
|
||||||
|
|
||||||
|
|
||||||
|
def _make_frames(t, h=2, w=2):
|
||||||
|
return np.arange(t * h * w * 3, dtype=np.uint8).reshape(t, h, w, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fit_frames_trim():
|
||||||
|
frames = _make_frames(10)
|
||||||
|
out = _fit_frames_to_length(frames, 4)
|
||||||
|
assert out.shape == (4, 2, 2, 3)
|
||||||
|
np.testing.assert_array_equal(out, frames[:4])
|
||||||
|
|
||||||
|
|
||||||
|
def test_fit_frames_passthrough_when_equal():
|
||||||
|
frames = _make_frames(5)
|
||||||
|
out = _fit_frames_to_length(frames, 5)
|
||||||
|
assert out is frames or np.array_equal(out, frames)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fit_frames_extends_with_pingpong():
|
||||||
|
frames = _make_frames(3)
|
||||||
|
out = _fit_frames_to_length(frames, 8)
|
||||||
|
assert out.shape == (8, 2, 2, 3)
|
||||||
|
# First 3 frames match the original
|
||||||
|
np.testing.assert_array_equal(out[:3], frames)
|
||||||
|
# Next 3 are the reverse (ping-pong)
|
||||||
|
np.testing.assert_array_equal(out[3:6], frames[::-1])
|
||||||
|
# Then forward again
|
||||||
|
np.testing.assert_array_equal(out[6:8], frames[:2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_fit_frames_zero_target_returns_original():
|
||||||
|
frames = _make_frames(3)
|
||||||
|
out = _fit_frames_to_length(frames, 0)
|
||||||
|
np.testing.assert_array_equal(out, frames)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_uint8_rgb_from_float():
|
||||||
|
arr = np.ones((5, 2, 2, 3), dtype=np.float32) * 0.5
|
||||||
|
out = _ensure_uint8_rgb(arr)
|
||||||
|
assert out.dtype == np.uint8
|
||||||
|
assert out.shape == (5, 2, 2, 3)
|
||||||
|
assert out[0, 0, 0, 0] == 127
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_uint8_rgb_promotes_3d_to_4d():
|
||||||
|
arr = np.zeros((2, 2, 3), dtype=np.uint8)
|
||||||
|
out = _ensure_uint8_rgb(arr)
|
||||||
|
assert out.shape == (1, 2, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_uint8_rgb_clips_float_out_of_range():
|
||||||
|
arr = np.ones((1, 1, 1, 3), dtype=np.float32) * 2.0 # 2.0 → clipped to 255
|
||||||
|
out = _ensure_uint8_rgb(arr)
|
||||||
|
assert out[0, 0, 0, 0] == 255
|
||||||
|
arr2 = np.ones((1, 1, 1, 3), dtype=np.float32) * -1.0
|
||||||
|
out2 = _ensure_uint8_rgb(arr2)
|
||||||
|
assert out2[0, 0, 0, 0] == 0
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
"""Unit tests for the ffmpeg muxer.
|
||||||
|
|
||||||
|
Requires ``ffmpeg`` on PATH. On Windows, if ffmpeg is not installed these
|
||||||
|
tests are skipped (they will run inside the Docker image where ffmpeg is
|
||||||
|
always present).
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import struct
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from server.video_models.muxer import frames_and_audio_to_mp4, frames_to_mp4_loop
|
||||||
|
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.skipif(
|
||||||
|
shutil.which("ffmpeg") is None,
|
||||||
|
reason="ffmpeg not installed locally; run these inside Docker",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _rgb_frames(t, h=64, w=64):
|
||||||
|
"""Coloured checker frames so the encoder has real content."""
|
||||||
|
frames = np.zeros((t, h, w, 3), dtype=np.uint8)
|
||||||
|
for i in range(t):
|
||||||
|
frames[i, :, :, 0] = (i * 20) % 255
|
||||||
|
frames[i, :h // 2, :, 1] = 255
|
||||||
|
frames[i, :, :w // 2, 2] = 255
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def test_frames_to_mp4_loop_produces_mp4_bytes():
|
||||||
|
frames = _rgb_frames(8)
|
||||||
|
data = frames_to_mp4_loop(frames, fps=16)
|
||||||
|
assert isinstance(data, bytes)
|
||||||
|
assert len(data) > 0
|
||||||
|
# MP4 files start with an ftyp box: 4 bytes size + 'ftyp'
|
||||||
|
assert data[4:8] == b"ftyp"
|
||||||
|
|
||||||
|
|
||||||
|
def test_frames_and_audio_to_mp4_produces_mp4_bytes():
|
||||||
|
frames = _rgb_frames(16)
|
||||||
|
# 1s silent audio at 24kHz
|
||||||
|
audio = np.zeros(24000, dtype=np.float32)
|
||||||
|
data = frames_and_audio_to_mp4(frames, audio, sample_rate=24000, fps=16)
|
||||||
|
assert isinstance(data, bytes)
|
||||||
|
assert len(data) > 0
|
||||||
|
assert data[4:8] == b"ftyp"
|
||||||
|
|
||||||
|
|
||||||
|
def test_frames_to_mp4_loop_rejects_empty():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
frames_to_mp4_loop(np.empty((0, 64, 64, 3), dtype=np.uint8), fps=16)
|
||||||
|
|
||||||
|
|
||||||
|
def test_frames_and_audio_to_mp4_rejects_empty_audio():
|
||||||
|
frames = _rgb_frames(4)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
frames_and_audio_to_mp4(
|
||||||
|
frames, np.empty(0, dtype=np.float32), sample_rate=24000, fps=16
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_frames_to_mp4_loop_rejects_wrong_shape():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
frames_to_mp4_loop(np.zeros((4, 64, 64), dtype=np.uint8), fps=16)
|
||||||
@@ -0,0 +1,144 @@
|
|||||||
|
"""Unit test for the video-mode branch in ConversationSession.
|
||||||
|
|
||||||
|
Stubs every model involved (ASR, LLM, TTS, VideoEngine) so we can verify:
|
||||||
|
1. When video_engine is not ready, the existing PCM streaming path runs.
|
||||||
|
2. When video_engine IS ready, the per-chunk PCM sends are skipped and a
|
||||||
|
single ``speaking_clip`` JSON + MP4 binary is sent instead.
|
||||||
|
|
||||||
|
Pure asyncio; no CUDA, no real models.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import types
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from server.pipeline import ConversationSession
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeVAD:
|
||||||
|
is_speaking = False
|
||||||
|
def process_chunk(self, _): return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeASR:
|
||||||
|
def __init__(self, text="hello"):
|
||||||
|
self.text = text
|
||||||
|
def transcribe(self, _): return self.text
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLLM:
|
||||||
|
def __init__(self, response="Hi there."):
|
||||||
|
self.response = response
|
||||||
|
def generate(self, *_a, **_k):
|
||||||
|
return self.response, None
|
||||||
|
def trim_cache(self, state, _): return state
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTTSIterable:
|
||||||
|
"""Drop-in replacement for Kokoro's pipeline(..) generator."""
|
||||||
|
def __init__(self, chunks):
|
||||||
|
self._chunks = chunks
|
||||||
|
def __call__(self, segment, voice=None):
|
||||||
|
for i, audio in enumerate(self._chunks):
|
||||||
|
yield f"w{i}", None, audio
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTTSEngine:
|
||||||
|
def __init__(self, chunks):
|
||||||
|
self.pipeline = _FakeTTSIterable(chunks)
|
||||||
|
self.voice = "v"
|
||||||
|
self.sample_rate = 24000
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeVideoEngineReady:
|
||||||
|
class _Cfg:
|
||||||
|
mode = "reflective"
|
||||||
|
cfg = _Cfg()
|
||||||
|
def __init__(self):
|
||||||
|
self.called_with = None
|
||||||
|
def is_ready(self): return True
|
||||||
|
def generate_speaking_clip(self, audio, sr, reply_text):
|
||||||
|
self.called_with = {"len": len(audio), "sr": sr, "reply": reply_text}
|
||||||
|
return b"FAKE_MP4_BYTES"
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeModelsBase:
|
||||||
|
def __init__(self, tts_chunks):
|
||||||
|
self.asr_engine = _FakeASR()
|
||||||
|
self.llm_engine = _FakeLLM()
|
||||||
|
self.tts_engine = _FakeTTSEngine(tts_chunks)
|
||||||
|
def create_vad(self): return _FakeVAD()
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeModelsStreaming(_FakeModelsBase):
|
||||||
|
video_engine = None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeModelsVideo(_FakeModelsBase):
|
||||||
|
def __init__(self, tts_chunks):
|
||||||
|
super().__init__(tts_chunks)
|
||||||
|
self.video_engine = _FakeVideoEngineReady()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streaming_path_when_video_engine_absent():
|
||||||
|
json_sent: list = []
|
||||||
|
bytes_sent: list = []
|
||||||
|
|
||||||
|
async def send_json(d): json_sent.append(d)
|
||||||
|
async def send_bytes(b): bytes_sent.append(b)
|
||||||
|
|
||||||
|
chunks = [
|
||||||
|
np.ones(240, dtype=np.float32),
|
||||||
|
np.ones(480, dtype=np.float32),
|
||||||
|
]
|
||||||
|
models = _FakeModelsStreaming(tts_chunks=chunks)
|
||||||
|
session = ConversationSession(models, send_json, send_bytes)
|
||||||
|
await session._process_utterance(np.zeros(16000, dtype=np.float32))
|
||||||
|
|
||||||
|
# PCM bytes were sent (one per TTS chunk).
|
||||||
|
assert len(bytes_sent) == 2
|
||||||
|
# Per-chunk response_text messages were sent (not video's one-shot).
|
||||||
|
text_msgs = [m for m in json_sent if m.get("type") == "response_text"]
|
||||||
|
assert any(not m.get("final") for m in text_msgs)
|
||||||
|
# No speaking_clip envelope
|
||||||
|
assert not any(m.get("type") == "speaking_clip" for m in json_sent)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_video_path_when_engine_ready():
|
||||||
|
json_sent: list = []
|
||||||
|
bytes_sent: list = []
|
||||||
|
|
||||||
|
async def send_json(d): json_sent.append(d)
|
||||||
|
async def send_bytes(b): bytes_sent.append(b)
|
||||||
|
|
||||||
|
chunks = [
|
||||||
|
np.full(480, 0.5, dtype=np.float32),
|
||||||
|
np.full(480, 0.25, dtype=np.float32),
|
||||||
|
]
|
||||||
|
models = _FakeModelsVideo(tts_chunks=chunks)
|
||||||
|
session = ConversationSession(models, send_json, send_bytes)
|
||||||
|
await session._process_utterance(np.zeros(16000, dtype=np.float32))
|
||||||
|
|
||||||
|
# MP4 blob was sent once.
|
||||||
|
assert bytes_sent == [b"FAKE_MP4_BYTES"]
|
||||||
|
# speaking_clip envelope was sent exactly once.
|
||||||
|
envelopes = [m for m in json_sent if m.get("type") == "speaking_clip"]
|
||||||
|
assert len(envelopes) == 1
|
||||||
|
assert envelopes[0]["size_bytes"] == len(b"FAKE_MP4_BYTES")
|
||||||
|
assert envelopes[0]["text"] == "Hi there."
|
||||||
|
|
||||||
|
# The video engine received the concatenated audio.
|
||||||
|
ve = models.video_engine
|
||||||
|
assert ve.called_with is not None
|
||||||
|
assert ve.called_with["len"] == 960 # 480 + 480
|
||||||
|
assert ve.called_with["reply"] == "Hi there."
|
||||||
|
|
||||||
|
# No per-chunk PCM bytes were streamed (video path suppresses them).
|
||||||
|
# Only the MP4 blob is in bytes_sent.
|
||||||
|
assert len(bytes_sent) == 1
|
||||||
@@ -0,0 +1,138 @@
|
|||||||
|
"""Unit tests for VideoConfig parsing and LoRASpec validation.
|
||||||
|
|
||||||
|
Pure-python, no model imports, no CUDA, no ffmpeg. Safe for Windows CI.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from server.video import VideoConfig, LoRASpec
|
||||||
|
|
||||||
|
|
||||||
|
def test_defaults_when_raw_is_empty():
|
||||||
|
cfg = VideoConfig.from_dict({})
|
||||||
|
assert cfg.enabled is False
|
||||||
|
assert cfg.backend == "lightx2v"
|
||||||
|
assert cfg.mode == "reflective"
|
||||||
|
assert cfg.resolution == 480
|
||||||
|
assert cfg.fps == 16
|
||||||
|
assert cfg.library_base_clip_count == 4
|
||||||
|
assert cfg.reflective_prompt_reply_words == 18
|
||||||
|
assert cfg.loras == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_defaults_when_raw_is_none():
|
||||||
|
cfg = VideoConfig.from_dict(None) # type: ignore[arg-type]
|
||||||
|
assert cfg.enabled is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_library_section_override():
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{"enabled": True, "mode": "library", "library": {"base_clip_count": 7, "base_clip_seconds": 3}}
|
||||||
|
)
|
||||||
|
assert cfg.enabled is True
|
||||||
|
assert cfg.mode == "library"
|
||||||
|
assert cfg.library_base_clip_count == 7
|
||||||
|
assert cfg.library_base_clip_seconds == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_reflective_section_override():
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{
|
||||||
|
"reflective": {
|
||||||
|
"clip_seconds": 9,
|
||||||
|
"clip_prompt_template": "my template: {reply_hint}",
|
||||||
|
"prompt_reply_words": 5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert cfg.reflective_clip_seconds == 9
|
||||||
|
assert cfg.reflective_prompt_template == "my template: {reply_hint}"
|
||||||
|
assert cfg.reflective_prompt_reply_words == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_parse_minimal():
|
||||||
|
cfg = VideoConfig.from_dict({"loras": [{"path": "/tmp/a.safetensors"}]})
|
||||||
|
assert len(cfg.loras) == 1
|
||||||
|
lora = cfg.loras[0]
|
||||||
|
assert lora.path == "/tmp/a.safetensors"
|
||||||
|
assert lora.weight == 1.0
|
||||||
|
assert lora.target == "both"
|
||||||
|
assert lora.name is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_parse_full():
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{
|
||||||
|
"loras": [
|
||||||
|
{
|
||||||
|
"path": "/tmp/a.safetensors",
|
||||||
|
"weight": 0.7,
|
||||||
|
"target": "both",
|
||||||
|
"name": "style-a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "/tmp/b.safetensors",
|
||||||
|
"weight": 0.4,
|
||||||
|
"target": "both",
|
||||||
|
"name": "style-b",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert len(cfg.loras) == 2
|
||||||
|
assert cfg.loras[0].target == "both"
|
||||||
|
assert cfg.loras[0].name == "style-a"
|
||||||
|
assert cfg.loras[1].target == "both"
|
||||||
|
assert cfg.loras[1].weight == 0.4
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_legacy_moe_target_coerced_to_both():
|
||||||
|
"""Legacy MoE configs with target='high_noise'/'low_noise' get coerced."""
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{
|
||||||
|
"loras": [
|
||||||
|
{"path": "/tmp/hi.safetensors", "target": "high_noise"},
|
||||||
|
{"path": "/tmp/lo.safetensors", "target": "low_noise"},
|
||||||
|
{"path": "/tmp/x.safetensors", "target": "bogus"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert all(l.target == "both" for l in cfg.loras)
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_entries_without_path_are_dropped():
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{"loras": [{"weight": 0.5}, {"path": "/tmp/ok.safetensors"}, None]}
|
||||||
|
)
|
||||||
|
assert len(cfg.loras) == 1
|
||||||
|
assert cfg.loras[0].path == "/tmp/ok.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
def test_models_section_override():
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{
|
||||||
|
"models": {
|
||||||
|
"wan22_base_repo": "/local/weights/wan22",
|
||||||
|
"wan22_dit_repo": "/local/weights/wan22-dit",
|
||||||
|
"wan22_dit_quant_scheme": "gguf-Q4_K_M",
|
||||||
|
"wan22_config_json": "/local/cfg/turbo.json",
|
||||||
|
"wan22_model_cls": "wan2.2",
|
||||||
|
"musetalk_path": "/local/weights/musetalk",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert cfg.wan22_base_repo == "/local/weights/wan22"
|
||||||
|
assert cfg.wan22_dit_repo == "/local/weights/wan22-dit"
|
||||||
|
assert cfg.wan22_dit_quant_scheme == "gguf-Q4_K_M"
|
||||||
|
assert cfg.wan22_config_json == "/local/cfg/turbo.json"
|
||||||
|
assert cfg.wan22_model_cls == "wan2.2"
|
||||||
|
assert cfg.musetalk_model_path == "/local/weights/musetalk"
|
||||||
|
|
||||||
|
|
||||||
|
def test_models_section_defaults_to_5b_turbo():
|
||||||
|
cfg = VideoConfig.from_dict({})
|
||||||
|
assert cfg.wan22_base_repo == "Wan-AI/Wan2.2-TI2V-5B"
|
||||||
|
assert cfg.wan22_dit_repo == "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||||
|
assert cfg.wan22_dit_quant_scheme == "gguf-Q8_0"
|
||||||
|
assert cfg.wan22_t5_quantized is True
|
||||||
|
assert cfg.wan22_model_cls == "wan2.2"
|
||||||
|
assert cfg.wan22_config_json == "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
"""Unit tests for pure-python logic inside VideoEngine.
|
||||||
|
|
||||||
|
No models are loaded: we instantiate ``VideoEngine`` and hand-stub its
|
||||||
|
``_wan22`` / ``_musetalk`` attributes to test prompt derivation, library
|
||||||
|
round-robin, and frame fitting.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from server.video import VideoConfig, VideoEngine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def engine():
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"mode": "reflective",
|
||||||
|
"fps": 16,
|
||||||
|
"reflective": {
|
||||||
|
"clip_prompt_template": "A: {reply_hint} B",
|
||||||
|
"prompt_reply_words": 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return VideoEngine(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_prompt_truncates_to_word_limit(engine):
|
||||||
|
out = engine._derive_prompt("one two three four five six seven eight")
|
||||||
|
assert out == "A: one two three four five B"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_prompt_handles_empty_reply(engine):
|
||||||
|
out = engine._derive_prompt("")
|
||||||
|
assert out == "A: calm and friendly B"
|
||||||
|
out2 = engine._derive_prompt(None) # type: ignore[arg-type]
|
||||||
|
assert out2 == "A: calm and friendly B"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_prompt_strips_and_passes_through(engine):
|
||||||
|
out = engine._derive_prompt(" hello world ")
|
||||||
|
assert out == "A: hello world B"
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_ready_false_without_models(engine):
|
||||||
|
# Models haven't been loaded — is_ready must be False so the pipeline
|
||||||
|
# falls back to the PCM streaming path.
|
||||||
|
assert engine.is_ready() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_pick_library_frames_round_robin(engine):
|
||||||
|
engine.cfg.mode = "library"
|
||||||
|
engine.cfg.fps = 2
|
||||||
|
# Two base clips, 4 frames each.
|
||||||
|
a = np.tile(np.array([[[[0, 0, 0]]]], dtype=np.uint8), (4, 1, 1, 1))
|
||||||
|
b = np.tile(np.array([[[[255, 255, 255]]]], dtype=np.uint8), (4, 1, 1, 1))
|
||||||
|
engine.speaking_base_frames = [a, b]
|
||||||
|
# 2s of audio at 16kHz → 4 frames at fps=2
|
||||||
|
audio = np.zeros(16000 * 2, dtype=np.float32)
|
||||||
|
|
||||||
|
f1 = engine._pick_library_frames(audio, 16000)
|
||||||
|
f2 = engine._pick_library_frames(audio, 16000)
|
||||||
|
f3 = engine._pick_library_frames(audio, 16000)
|
||||||
|
assert f1.shape == (4, 1, 1, 3)
|
||||||
|
assert f1[0, 0, 0, 0] == 0 # first pick = clip A
|
||||||
|
assert f2[0, 0, 0, 0] == 255 # second pick = clip B
|
||||||
|
assert f3[0, 0, 0, 0] == 0 # wraps back to A
|
||||||
|
|
||||||
|
|
||||||
|
def test_pick_library_frames_trims_to_audio_duration(engine):
|
||||||
|
engine.cfg.mode = "library"
|
||||||
|
engine.cfg.fps = 4
|
||||||
|
frames = np.zeros((20, 1, 1, 3), dtype=np.uint8)
|
||||||
|
engine.speaking_base_frames = [frames]
|
||||||
|
# 1s audio → 4 frames
|
||||||
|
audio = np.zeros(16000, dtype=np.float32)
|
||||||
|
out = engine._pick_library_frames(audio, 16000)
|
||||||
|
assert out.shape == (4, 1, 1, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pick_library_frames_loops_for_long_audio(engine):
|
||||||
|
engine.cfg.mode = "library"
|
||||||
|
engine.cfg.fps = 4
|
||||||
|
frames = np.zeros((4, 1, 1, 3), dtype=np.uint8)
|
||||||
|
engine.speaking_base_frames = [frames]
|
||||||
|
# 3s audio → 12 frames, base has only 4
|
||||||
|
audio = np.zeros(16000 * 3, dtype=np.float32)
|
||||||
|
out = engine._pick_library_frames(audio, 16000)
|
||||||
|
assert out.shape == (12, 1, 1, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pick_library_frames_raises_when_empty(engine):
|
||||||
|
engine.cfg.mode = "library"
|
||||||
|
engine.speaking_base_frames = []
|
||||||
|
with pytest.raises(RuntimeError, match="no pre-baked base clips"):
|
||||||
|
engine._pick_library_frames(np.zeros(100, dtype=np.float32), 16000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_speaking_clip_raises_when_not_ready(engine):
|
||||||
|
with pytest.raises(RuntimeError, match="not ready"):
|
||||||
|
engine.generate_speaking_clip(
|
||||||
|
audio_f32=np.zeros(100, dtype=np.float32),
|
||||||
|
sample_rate=16000,
|
||||||
|
reply_text="hi",
|
||||||
|
)
|
||||||
+1
Submodule third_party/MuseTalk added at ca5b7a8f28
Reference in New Issue
Block a user