Add LightX2V + Wan2.2-TI2V-5B-Turbo GGUF experiment

Benchmarks the dense 5B Turbo model (Q8_0 GGUF + fp8 T5) as a
lower-VRAM alternative to the 14B MoE pipeline. Includes dtype
patches for dense WanModel, Wan 2.2 VAE config (48 channels, 16x
spatial), and Blackwell fp8 workaround.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-16 01:27:45 -04:00
parent 56923ff424
commit 9debc56137
8 changed files with 407 additions and 0 deletions
+3
View File
@@ -0,0 +1,3 @@
[submodule "third_party/MuseTalk"]
path = third_party/MuseTalk
url = https://git.hetherman.cloud/bhetherman/MuseTalk.git
+1
View File
@@ -0,0 +1 @@
_out/
+65
View File
@@ -0,0 +1,65 @@
# LightX2V + Wan2.2-TI2V-5B-Turbo (GGUF) Experiment
Swap the 14B MoE distill for the dense 5B Turbo model, keeping the LightX2V backend.
Hypothesis: half the parameters → lower VRAM footprint (can coexist with the running
server) and faster per-step compute, with the Turbo 4-step distill preserving wall time.
## Config
- **Model**: `hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF` (Q8_0 by default — swap to Q4_K_M via env)
- **Base repo** (configs, T5, VAE): `Wan-AI/Wan2.2-TI2V-5B`
- **model_cls**: `wan2.2` (dense, single DIT — not MoE)
- **Steps**: 4 (Turbo distill)
- **Resolution**: 480×480, 81 frames @ 16 fps
## Key implementation details
- **Dense model (`wan2.2`)**: Uses single DIT checkpoint, not MoE — requires different dtype patching than the 14B pipeline
- **GGUF dequant → fp16**: Requires `DTYPE=FP16` and patches for T5 (bf16→fp16 wrapper), VAE (→fp16), and DIT pre/post weights (fp32→fp16)
- **Wan 2.2 VAE**: 48 latent channels with 16× spatial compression (vs 16 channels / 8× for Wan 2.1) — config must set `vae_stride: [4,16,16]` and `num_channels_latents: 48`
- **fp8 T5**: Uses `lightx2v/Encoders` fp8 checkpoint (~4.9 GB vs ~11.4 GB bf16)
- **Blackwell (SM120)**: Needs `_patch_fp8_scaled_mm_for_blackwell` to replace sgl_kernel's fp8 GEMM
## Why a separate container
Reuses the existing `voice-chat-voice-chat` image (LightX2V already installed) but runs
under its own compose profile so it doesn't interfere with the live server volumes or
startup. Shares the HF cache volume so model downloads are reused.
## Running
```bash
# Ensure main image is built
docker compose build voice-chat
# Stage model (downloads base + Turbo Q8 GGUF, ~6 GB)
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/setup_model.py
# Run benchmark
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py
```
Reports peak VRAM and wall time for an 81-frame 480p clip.
## Results
| Metric | Value |
|--------|-------|
| Model load | ~43s |
| VRAM after load | 6.53 GB |
| T5 encode | ~1s |
| VAE encode | ~0.5s |
Awaiting full end-to-end benchmark completion for wall time and peak VRAM.
## Go / no-go criteria
- **Go**: < 45s per 81-frame clip AND peak VRAM < 12 GB (leaves ~20 GB for the server)
- **No-go**: keep the 14B MoE Q4_K_M pipeline
### Baselines
- **vLLM-Omni + fp16 Turbo-5B**: 1663s / 22.5 GB — decisive no-go
- **LightX2V + 14B MoE Q4_K_M**: ~30s/clip, ~14.5 GB VRAM (current production pipeline)
+40
View File
@@ -0,0 +1,40 @@
{
"_comment": "LightX2V config for Wan2.2-TI2V-5B-Turbo (dense, GGUF). Single DIT checkpoint (not MoE). dit_quantized_ckpt is filled in at runtime by setup_model.py / test_i2v.py.",
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"resize_mode": "adaptive",
"resolution": "480p",
"target_height": 480,
"target_width": 480,
"fps": 16,
"vae_stride": [4, 16, 16],
"num_channels_latents": 48,
"self_attn_1_type": "torch_sdpa",
"cross_attn_1_type": "torch_sdpa",
"cross_attn_2_type": "torch_sdpa",
"modulate_type": "torch",
"rope_type": "torch",
"sample_guide_scale": 1.0,
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": false,
"offload_granularity": "model",
"t5_cpu_offload": true,
"vae_cpu_offload": false,
"use_image_encoder": false,
"denoising_step_list": [1000, 750, 500, 250],
"dit_quantized": true,
"dit_quant_scheme": "gguf-Q8_0",
"t5_quantized": true,
"t5_quant_scheme": "fp8-sgl"
}
@@ -0,0 +1,26 @@
services:
lightx2v-5b:
image: voice-chat-voice-chat:latest
volumes:
- huggingface-cache:/cache/huggingface
- ../../:/app
working_dir: /app
environment:
- DTYPE=FP16
- HF_HOME=/cache/huggingface
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
shm_size: "8g"
ipc: host
profiles:
- experimental
volumes:
huggingface-cache:
name: voice-chat_huggingface-cache
external: true
+54
View File
@@ -0,0 +1,54 @@
"""Stage Wan2.2-TI2V-5B-Turbo GGUF pipeline for LightX2V.
Downloads:
1. Base `Wan-AI/Wan2.2-TI2V-5B` snapshot (configs, T5, VAE — skip bf16 DIT shards).
2. Turbo Q8 GGUF DIT from `hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF`.
Quant file can be overridden via GGUF_FILE env (default Q8_0).
Idempotent: huggingface_hub handles caching.
"""
from __future__ import annotations
import os
from huggingface_hub import hf_hub_download, snapshot_download
BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B"
GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
GGUF_FILE = os.environ.get(
"GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf"
)
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
def main() -> None:
print(f"\n=== 1/2 Snapshot base pipeline {BASE_REPO} ===", flush=True)
# The base repo ships bf16 DIT shards we don't need (we use the Turbo GGUF instead).
base_dir = snapshot_download(
repo_id=BASE_REPO,
ignore_patterns=[
"*.pt",
"diffusion_pytorch_model*.safetensors",
],
)
print(f"Base pipeline at: {base_dir}")
print(f"\n=== 2/3 Download {GGUF_FILE} from {GGUF_REPO} ===", flush=True)
gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE)
print(f"GGUF DIT at: {gguf_path}")
print(f"\n=== 3/3 Download fp8 T5 from {T5_FP8_REPO} ===", flush=True)
t5_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
print(f"fp8 T5 at: {t5_path}")
print(f"\n{'=' * 50}")
print("Ready. Export to test_i2v.py via env:")
print(f" BASE_DIR={base_dir}")
print(f" DIT_GGUF={gguf_path}")
print(f" T5_FP8={t5_path}")
if __name__ == "__main__":
main()
+217
View File
@@ -0,0 +1,217 @@
"""Benchmark Wan2.2-TI2V-5B-Turbo i2v under LightX2V.
Uses the dense `wan2.2` model_cls with a single Q8 GGUF DIT checkpoint.
Applies the same GGUF dtype patches as server/video_models/wan22.py (T5→bf16
wrapper, VAE→fp16, fp32 DIT pre/post weights→fp16).
Measures peak VRAM and wall time for an 81-frame 480p clip from the sample avatar.
Run:
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \\
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py
"""
from __future__ import annotations
import argparse
import json
import os
import tempfile
import time
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download, snapshot_download
BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B"
GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
GGUF_FILE = os.environ.get("GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf")
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
CONFIG_JSON = Path(__file__).parent / "config.json"
SAMPLE_AVATAR = "/app/tests/component/sample_avatar.png"
OUTPUT = Path("/app/experimental/lightx2v_5b/_out/turbo_5b.mp4")
PROMPT = "a humanoid robot looking at the camera, shaking their head left and right, soft focus background"
SEED = 42
def human_gb(n: int) -> str:
return f"{n / (1024 ** 3):.2f} GB"
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
"""Recursively find and cast fp32 tensors to fp16 on any object tree."""
if visited is None:
visited = set()
obj_id = id(obj)
if obj_id in visited or depth > 6:
return 0
visited.add(obj_id)
n = 0
for attr_name in dir(obj):
if attr_name.startswith("__"):
continue
try:
val = getattr(obj, attr_name)
except Exception:
continue
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
try:
setattr(obj, attr_name, val.to(torch.float16))
n += 1
except Exception:
pass
elif hasattr(val, "__dict__") and not callable(val):
n += _cast_all_fp32_tensors(val, visited, depth + 1)
return n
def build_args(model_path: str, config_json: str) -> argparse.Namespace:
return argparse.Namespace(
seed=SEED,
model_cls="wan2.2",
task="i2v",
support_tasks=[],
model_path=model_path,
sf_model_path=None,
config_json=config_json,
use_prompt_enhancer=False,
prompt="",
negative_prompt="",
image_path="",
last_frame_path="",
audio_path="",
image_strength="1.0",
image_frame_idx="",
src_ref_images=None,
src_video=None,
src_mask=None,
src_pose_path=None,
src_face_path=None,
src_bg_path=None,
src_mask_path=None,
pose=None,
action_path=None,
action_ckpt=None,
save_result_path=None,
return_result_tensor=False,
target_shape=[],
target_video_length=81,
aspect_ratio="",
video_path=None,
sr_ratio=2.0,
)
def main() -> None:
OUTPUT.parent.mkdir(parents=True, exist_ok=True)
print(f"\n=== Stage model ===", flush=True)
base_dir = snapshot_download(
repo_id=BASE_REPO,
ignore_patterns=["*.pt", "diffusion_pytorch_model*.safetensors"],
)
gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE)
t5_fp8_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
print(f" base_dir: {base_dir}")
print(f" dit_gguf: {gguf_path}")
print(f" t5_fp8: {t5_fp8_path}")
with open(CONFIG_JSON, "r", encoding="utf-8") as f:
cfg = json.load(f)
cfg.pop("_comment", None)
cfg["dit_quantized_ckpt"] = gguf_path
cfg["t5_quantized_ckpt"] = t5_fp8_path
tmp = tempfile.NamedTemporaryFile(
prefix="wan22_5b_", suffix=".json",
mode="w", delete=False, encoding="utf-8",
)
json.dump(cfg, tmp, indent=2)
tmp.close()
print(f" runtime config: {tmp.name}")
# Import LightX2V after env is set.
import sys
if "/app" not in sys.path:
sys.path.insert(0, "/app")
from server.video_models.wan22 import _patch_fp8_scaled_mm_for_blackwell
from lightx2v.infer import init_runner
from lightx2v.utils.input_info import (
init_empty_input_info,
update_input_info_from_dict,
)
from lightx2v.utils.set_config import set_config
_patch_fp8_scaled_mm_for_blackwell()
args = build_args(base_dir, tmp.name)
print(f"\n=== set_config + init_runner ===", flush=True)
torch.cuda.reset_peak_memory_stats()
t_load = time.perf_counter()
config = set_config(args)
runner = init_runner(config)
load_time = time.perf_counter() - t_load
vram_load = torch.cuda.memory_allocated()
print(f"[load] time={load_time:.1f}s vram={human_gb(vram_load)}", flush=True)
# GGUF dtype patches: flip DTYPE to FP16, patch T5/VAE/DIT to match.
os.environ["DTYPE"] = "FP16"
from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE
GET_DTYPE.cache_clear()
# Reuse the patch methods from Wan22Pipeline as standalone functions
# by constructing a minimal object with ._runner.
from server.video_models.wan22 import Wan22Pipeline
shim = Wan22Pipeline.__new__(Wan22Pipeline)
shim._runner = runner
Wan22Pipeline._patch_t5_dtype_for_gguf(shim)
Wan22Pipeline._patch_vae_dtype_for_gguf(shim)
Wan22Pipeline._patch_dit_fp32_weights_for_gguf(shim)
# Dense model: patch_dit_fp32 above expects MoE wrapper (runner.model.model).
# For dense wan2.2, the model is directly at runner.model — patch it explicitly.
n = Wan22Pipeline._cast_fp32_dit_weights_in_model(runner.model)
print(f"[patch] cast {n} fp32 DIT weights to fp16 (dense model)")
# Sweep all nested objects for any remaining fp32 tensors (conv3d bias, etc).
n_extra = _cast_all_fp32_tensors(runner.model)
print(f"[patch] cast {n_extra} additional fp32 tensors to fp16")
input_info = init_empty_input_info(args.task, args.support_tasks)
print(f"\n=== generate ===", flush=True)
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
out_path = tf.name
update_input_info_from_dict(
input_info,
{
"seed": SEED,
"prompt": PROMPT,
"negative_prompt": "",
"image_path": SAMPLE_AVATAR,
"save_result_path": out_path,
"target_video_length": 81,
"return_result_tensor": False,
},
)
t_gen = time.perf_counter()
runner.run_pipeline(input_info)
gen_time = time.perf_counter() - t_gen
peak_vram = torch.cuda.max_memory_allocated()
print(f"\n[generate] wall_time={gen_time:.1f}s peak_vram={human_gb(peak_vram)}", flush=True)
if os.path.exists(out_path):
size = os.path.getsize(out_path)
OUTPUT.write_bytes(Path(out_path).read_bytes())
os.remove(out_path)
print(f"[output] wrote {OUTPUT} ({size} bytes)")
else:
print(f"[output] no file at {out_path}")
if __name__ == "__main__":
main()
Vendored Submodule
+1
Submodule third_party/MuseTalk added at ca5b7a8f28