Add LightX2V + Wan2.2-TI2V-5B-Turbo GGUF experiment
Benchmarks the dense 5B Turbo model (Q8_0 GGUF + fp8 T5) as a lower-VRAM alternative to the 14B MoE pipeline. Includes dtype patches for dense WanModel, Wan 2.2 VAE config (48 channels, 16x spatial), and Blackwell fp8 workaround.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@@ -0,0 +1,3 @@
[submodule "third_party/MuseTalk"]
	path = third_party/MuseTalk
	url = https://git.hetherman.cloud/bhetherman/MuseTalk.git
@@ -0,0 +1 @@
_out/
@@ -0,0 +1,65 @@
# LightX2V + Wan2.2-TI2V-5B-Turbo (GGUF) Experiment

Swap the 14B MoE distill for the dense 5B Turbo model, keeping the LightX2V backend.
Hypothesis: fewer parameters (5B dense vs. the 14B MoE) → lower VRAM footprint (can coexist
with the running server) and faster per-step compute, with the Turbo 4-step distill
preserving wall time.

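As a rough sanity check on the VRAM side of the hypothesis, the sketch below estimates the weight footprint of a Q8_0 checkpoint. It assumes a nominal 5e9 parameters, all stored in Q8_0 blocks (32 int8 weights plus one fp16 scale, 34 bytes per block); real GGUF files usually keep some tensors at higher precision, so treat the result as a ballpark rather than a measurement. The function names are illustrative and do not exist in this repo.

```python
# Back-of-envelope weight footprint for a GGUF Q8_0 checkpoint.
# Assumptions (not measured here): nominal parameter count, every tensor in Q8_0.

Q8_0_BLOCK_WEIGHTS = 32  # weights per Q8_0 block
Q8_0_BLOCK_BYTES = 34    # 32 int8 values + one fp16 scale


def q8_0_weight_bytes(n_params: int) -> int:
    """Bytes needed to store n_params weights in Q8_0 blocks (rounded up)."""
    n_blocks = -(-n_params // Q8_0_BLOCK_WEIGHTS)  # ceiling division
    return n_blocks * Q8_0_BLOCK_BYTES


def human_gib(n_bytes: int) -> str:
    """Format a byte count in binary gigabytes."""
    return f"{n_bytes / 1024 ** 3:.2f} GiB"


if __name__ == "__main__":
    print(human_gib(q8_0_weight_bytes(5_000_000_000)))  # 4.95 GiB
```

The same arithmetic gives 8.5 bits per weight for Q8_0, which is a quick way to compare other parameter counts.
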
## Config

- **Model**: `hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF` (Q8_0 by default; swap to Q4_K_M via the `GGUF_FILE` env var)
- **Base repo** (configs, T5, VAE): `Wan-AI/Wan2.2-TI2V-5B`
- **model_cls**: `wan2.2` (dense, single DIT, not MoE)
- **Steps**: 4 (Turbo distill)
- **Resolution**: 480×480, 81 frames @ 16 fps

## Key implementation details

- **Dense model (`wan2.2`)**: Uses a single DIT checkpoint, not MoE, so it requires different dtype patching than the 14B pipeline
- **GGUF dequant → fp16**: Requires `DTYPE=FP16` and patches for T5 (bf16→fp16 wrapper), VAE (→fp16), and DIT pre/post weights (fp32→fp16)
- **Wan 2.2 VAE**: 48 latent channels with 16× spatial compression (vs 16 channels / 8× for Wan 2.1); config must set `vae_stride: [4,16,16]` and `num_channels_latents: 48`
- **fp8 T5**: Uses `lightx2v/Encoders` fp8 checkpoint (~4.9 GB vs ~11.4 GB bf16)
- **Blackwell (SM120)**: Needs `_patch_fp8_scaled_mm_for_blackwell` to replace sgl_kernel's fp8 GEMM

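To make the VAE settings concrete, here is a small sketch that derives the latent tensor shape from `vae_stride` and `num_channels_latents`. It assumes the usual causal-VAE convention in which the first frame is encoded on its own, so the latent frame count is `(frames - 1) // t_stride + 1`; `latent_shape` is a hypothetical helper, not a LightX2V function.

```python
def latent_shape(frames, height, width, vae_stride=(4, 16, 16), channels=48):
    """Return the (C, T, H, W) latent shape for a clip.

    Assumes the first frame is encoded alone, so T = (frames - 1) // t_stride + 1.
    """
    t_stride, h_stride, w_stride = vae_stride
    if (frames - 1) % t_stride or height % h_stride or width % w_stride:
        raise ValueError("clip dimensions are not aligned to vae_stride")
    return (
        channels,
        (frames - 1) // t_stride + 1,
        height // h_stride,
        width // w_stride,
    )


if __name__ == "__main__":
    # Wan 2.2 VAE settings from this config: 48 channels, stride [4, 16, 16].
    print(latent_shape(81, 480, 480))  # (48, 21, 30, 30)
    # Wan 2.1-style settings for comparison: 16 channels, 8x spatial.
    print(latent_shape(81, 480, 480, vae_stride=(4, 8, 8), channels=16))  # (16, 21, 60, 60)
```

If the config keeps the Wan 2.1 values, the DIT would receive latents with the wrong channel count and spatial size, which is why both keys have to change together.
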
## Why a separate container

Reuses the existing `voice-chat-voice-chat` image (LightX2V already installed) but runs
under its own compose profile so it doesn't interfere with the live server volumes or
startup. Shares the HF cache volume so model downloads are reused.

## Running

```bash
# Ensure main image is built
docker compose build voice-chat

# Stage model (downloads base + Turbo Q8 GGUF, ~6 GB)
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \
  run --rm lightx2v-5b python /app/experimental/lightx2v_5b/setup_model.py

# Run benchmark
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \
  run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py
```

Reports peak VRAM and wall time for an 81-frame 480p clip.

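The benchmark prints its summary as a `[generate] wall_time=...s peak_vram=... GB` line. A minimal sketch for pulling those two numbers out of captured output follows; `parse_generate_line` is a hypothetical helper, and the sample values are placeholders, not measured results.

```python
import re

# Matches the summary line printed by test_i2v.py, e.g.
# "[generate] wall_time=1.0s peak_vram=2.00 GB" (placeholder values).
GENERATE_RE = re.compile(
    r"\[generate\]\s+wall_time=(?P<wall>[\d.]+)s\s+peak_vram=(?P<vram>[\d.]+) GB"
)


def parse_generate_line(log_text):
    """Return (wall_seconds, peak_vram_gb) from benchmark output, or None."""
    match = GENERATE_RE.search(log_text)
    if match is None:
        return None
    return float(match["wall"]), float(match["vram"])


if __name__ == "__main__":
    sample = "=== generate ===\n\n[generate] wall_time=1.0s peak_vram=2.00 GB\n"
    print(parse_generate_line(sample))  # (1.0, 2.0)
```
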
## Results

| Metric | Value |
|--------|-------|
| Model load | ~43s |
| VRAM after load | 6.53 GB |
| T5 encode | ~1s |
| VAE encode | ~0.5s |

Awaiting full end-to-end benchmark completion for wall time and peak VRAM.

## Go / no-go criteria

- **Go**: < 45s per 81-frame clip AND peak VRAM < 12 GB (leaves ~20 GB for the server)
- **No-go**: keep the 14B MoE Q4_K_M pipeline

### Baselines

- **vLLM-Omni + fp16 Turbo-5B**: 1663s / 22.5 GB (decisive no-go)
- **LightX2V + 14B MoE Q4_K_M**: ~30s/clip, ~14.5 GB VRAM (current production pipeline)

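The go rule is a simple conjunction of two thresholds, so it can be checked mechanically once the benchmark finishes. The sketch below encodes the thresholds from the criteria above and applies them to the vLLM-Omni baseline; `Run` and `is_go` are illustrative names, not part of the repo.

```python
from dataclasses import dataclass

GO_MAX_WALL_S = 45.0        # seconds per 81-frame clip
GO_MAX_PEAK_VRAM_GB = 12.0  # peak VRAM budget


@dataclass(frozen=True)
class Run:
    name: str
    wall_s: float
    peak_vram_gb: float


def is_go(run: Run) -> bool:
    """True only if the run beats both the wall-time and the VRAM threshold."""
    return run.wall_s < GO_MAX_WALL_S and run.peak_vram_gb < GO_MAX_PEAK_VRAM_GB


if __name__ == "__main__":
    vllm_omni = Run("vLLM-Omni + fp16 Turbo-5B", wall_s=1663.0, peak_vram_gb=22.5)
    print(is_go(vllm_omni))  # False
```

Both comparisons are strict, matching the `<` in the criteria, so a run at exactly 45s or exactly 12 GB counts as no-go.
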
@@ -0,0 +1,40 @@
{
  "_comment": "LightX2V config for Wan2.2-TI2V-5B-Turbo (dense, GGUF). Single DIT checkpoint (not MoE). dit_quantized_ckpt and t5_quantized_ckpt are filled in at runtime by test_i2v.py.",

  "infer_steps": 4,
  "target_video_length": 81,
  "text_len": 512,

  "resize_mode": "adaptive",
  "resolution": "480p",
  "target_height": 480,
  "target_width": 480,
  "fps": 16,

  "vae_stride": [4, 16, 16],
  "num_channels_latents": 48,

  "self_attn_1_type": "torch_sdpa",
  "cross_attn_1_type": "torch_sdpa",
  "cross_attn_2_type": "torch_sdpa",
  "modulate_type": "torch",
  "rope_type": "torch",

  "sample_guide_scale": 1.0,
  "sample_shift": 5.0,
  "enable_cfg": false,

  "cpu_offload": false,
  "offload_granularity": "model",
  "t5_cpu_offload": true,
  "vae_cpu_offload": false,

  "use_image_encoder": false,

  "denoising_step_list": [1000, 750, 500, 250],

  "dit_quantized": true,
  "dit_quant_scheme": "gguf-Q8_0",
  "t5_quantized": true,
  "t5_quant_scheme": "fp8-sgl"
}
@@ -0,0 +1,26 @@
services:
  lightx2v-5b:
    image: voice-chat-voice-chat:latest
    volumes:
      - huggingface-cache:/cache/huggingface
      - ../../:/app
    working_dir: /app
    environment:
      - DTYPE=FP16
      - HF_HOME=/cache/huggingface
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    shm_size: "8g"
    ipc: host
    profiles:
      - experimental

volumes:
  huggingface-cache:
    name: voice-chat_huggingface-cache
    external: true
@@ -0,0 +1,54 @@
"""Stage Wan2.2-TI2V-5B-Turbo GGUF pipeline for LightX2V.

Downloads:
  1. Base `Wan-AI/Wan2.2-TI2V-5B` snapshot (configs, T5, VAE; skips the bf16 DIT shards).
  2. Turbo Q8 GGUF DIT from `hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF`.
  3. fp8 T5 encoder from `lightx2v/Encoders`.

Quant file can be overridden via GGUF_FILE env (default Q8_0).

Idempotent: huggingface_hub handles caching.
"""
from __future__ import annotations

import os

from huggingface_hub import hf_hub_download, snapshot_download

BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B"
GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
GGUF_FILE = os.environ.get(
    "GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf"
)
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"


def main() -> None:
    print(f"\n=== 1/3 Snapshot base pipeline {BASE_REPO} ===", flush=True)
    # The base repo ships bf16 DIT shards we don't need (we use the Turbo GGUF instead).
    base_dir = snapshot_download(
        repo_id=BASE_REPO,
        ignore_patterns=[
            "*.pt",
            "diffusion_pytorch_model*.safetensors",
        ],
    )
    print(f"Base pipeline at: {base_dir}")

    print(f"\n=== 2/3 Download {GGUF_FILE} from {GGUF_REPO} ===", flush=True)
    gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE)
    print(f"GGUF DIT at: {gguf_path}")

    print(f"\n=== 3/3 Download fp8 T5 from {T5_FP8_REPO} ===", flush=True)
    t5_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
    print(f"fp8 T5 at: {t5_path}")

    print(f"\n{'=' * 50}")
    print("Ready. Export to test_i2v.py via env:")
    print(f"  BASE_DIR={base_dir}")
    print(f"  DIT_GGUF={gguf_path}")
    print(f"  T5_FP8={t5_path}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,217 @@
"""Benchmark Wan2.2-TI2V-5B-Turbo i2v under LightX2V.

Uses the dense `wan2.2` model_cls with a single Q8 GGUF DIT checkpoint.
Applies the same GGUF dtype patches as server/video_models/wan22.py (T5→bf16
wrapper, VAE→fp16, fp32 DIT pre/post weights→fp16).

Measures peak VRAM and wall time for an 81-frame 480p clip from the sample avatar.

Run:
    docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \\
        run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py
"""
from __future__ import annotations

import argparse
import json
import os
import tempfile
import time
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download, snapshot_download

BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B"
GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
GGUF_FILE = os.environ.get("GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf")
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"

CONFIG_JSON = Path(__file__).parent / "config.json"
SAMPLE_AVATAR = "/app/tests/component/sample_avatar.png"
OUTPUT = Path("/app/experimental/lightx2v_5b/_out/turbo_5b.mp4")

PROMPT = "a humanoid robot looking at the camera, shaking their head left and right, soft focus background"
SEED = 42


def human_gb(n: int) -> str:
    return f"{n / (1024 ** 3):.2f} GB"


def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
    """Recursively find and cast fp32 tensors to fp16 on any object tree."""
    if visited is None:
        visited = set()
    obj_id = id(obj)
    if obj_id in visited or depth > 6:
        return 0
    visited.add(obj_id)
    n = 0
    for attr_name in dir(obj):
        if attr_name.startswith("__"):
            continue
        try:
            val = getattr(obj, attr_name)
        except Exception:
            continue
        if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
            try:
                setattr(obj, attr_name, val.to(torch.float16))
                n += 1
            except Exception:
                pass
        elif hasattr(val, "__dict__") and not callable(val):
            n += _cast_all_fp32_tensors(val, visited, depth + 1)
    return n


def build_args(model_path: str, config_json: str) -> argparse.Namespace:
    return argparse.Namespace(
        seed=SEED,
        model_cls="wan2.2",
        task="i2v",
        support_tasks=[],
        model_path=model_path,
        sf_model_path=None,
        config_json=config_json,
        use_prompt_enhancer=False,
        prompt="",
        negative_prompt="",
        image_path="",
        last_frame_path="",
        audio_path="",
        image_strength="1.0",
        image_frame_idx="",
        src_ref_images=None,
        src_video=None,
        src_mask=None,
        src_pose_path=None,
        src_face_path=None,
        src_bg_path=None,
        src_mask_path=None,
        pose=None,
        action_path=None,
        action_ckpt=None,
        save_result_path=None,
        return_result_tensor=False,
        target_shape=[],
        target_video_length=81,
        aspect_ratio="",
        video_path=None,
        sr_ratio=2.0,
    )


def main() -> None:
    OUTPUT.parent.mkdir(parents=True, exist_ok=True)

    print("\n=== Stage model ===", flush=True)
    base_dir = snapshot_download(
        repo_id=BASE_REPO,
        ignore_patterns=["*.pt", "diffusion_pytorch_model*.safetensors"],
    )
    gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE)
    t5_fp8_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
    print(f"  base_dir: {base_dir}")
    print(f"  dit_gguf: {gguf_path}")
    print(f"  t5_fp8:   {t5_fp8_path}")

    with open(CONFIG_JSON, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    cfg.pop("_comment", None)
    cfg["dit_quantized_ckpt"] = gguf_path
    cfg["t5_quantized_ckpt"] = t5_fp8_path

    tmp = tempfile.NamedTemporaryFile(
        prefix="wan22_5b_", suffix=".json",
        mode="w", delete=False, encoding="utf-8",
    )
    json.dump(cfg, tmp, indent=2)
    tmp.close()
    print(f"  runtime config: {tmp.name}")

    # Import LightX2V after env is set.
    import sys
    if "/app" not in sys.path:
        sys.path.insert(0, "/app")
    from server.video_models.wan22 import _patch_fp8_scaled_mm_for_blackwell
    from lightx2v.infer import init_runner
    from lightx2v.utils.input_info import (
        init_empty_input_info,
        update_input_info_from_dict,
    )
    from lightx2v.utils.set_config import set_config

    _patch_fp8_scaled_mm_for_blackwell()

    args = build_args(base_dir, tmp.name)
    print("\n=== set_config + init_runner ===", flush=True)
    torch.cuda.reset_peak_memory_stats()
    t_load = time.perf_counter()
    config = set_config(args)
    runner = init_runner(config)
    load_time = time.perf_counter() - t_load
    vram_load = torch.cuda.memory_allocated()
    print(f"[load] time={load_time:.1f}s vram={human_gb(vram_load)}", flush=True)

    # GGUF dtype patches: flip DTYPE to FP16, patch T5/VAE/DIT to match.
    os.environ["DTYPE"] = "FP16"
    from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE
    GET_DTYPE.cache_clear()

    # Reuse the patch methods from Wan22Pipeline as standalone functions
    # by constructing a minimal object with ._runner.
    from server.video_models.wan22 import Wan22Pipeline

    shim = Wan22Pipeline.__new__(Wan22Pipeline)
    shim._runner = runner
    Wan22Pipeline._patch_t5_dtype_for_gguf(shim)
    Wan22Pipeline._patch_vae_dtype_for_gguf(shim)
    Wan22Pipeline._patch_dit_fp32_weights_for_gguf(shim)

    # Dense model: patch_dit_fp32 above expects MoE wrapper (runner.model.model).
    # For dense wan2.2, the model is directly at runner.model, so patch it explicitly.
    n = Wan22Pipeline._cast_fp32_dit_weights_in_model(runner.model)
    print(f"[patch] cast {n} fp32 DIT weights to fp16 (dense model)")
    # Sweep all nested objects for any remaining fp32 tensors (conv3d bias, etc).
    n_extra = _cast_all_fp32_tensors(runner.model)
    print(f"[patch] cast {n_extra} additional fp32 tensors to fp16")

    input_info = init_empty_input_info(args.task, args.support_tasks)

    print("\n=== generate ===", flush=True)
    # Reserve a temp path for the pipeline output. NamedTemporaryFile creates
    # an empty file here, so the check below looks at size, not just existence.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
        out_path = tf.name

    update_input_info_from_dict(
        input_info,
        {
            "seed": SEED,
            "prompt": PROMPT,
            "negative_prompt": "",
            "image_path": SAMPLE_AVATAR,
            "save_result_path": out_path,
            "target_video_length": 81,
            "return_result_tensor": False,
        },
    )

    t_gen = time.perf_counter()
    runner.run_pipeline(input_info)
    gen_time = time.perf_counter() - t_gen
    peak_vram = torch.cuda.max_memory_allocated()
    print(f"\n[generate] wall_time={gen_time:.1f}s peak_vram={human_gb(peak_vram)}", flush=True)

    size = os.path.getsize(out_path) if os.path.exists(out_path) else 0
    if size > 0:
        OUTPUT.write_bytes(Path(out_path).read_bytes())
        os.remove(out_path)
        print(f"[output] wrote {OUTPUT} ({size} bytes)")
    else:
        print(f"[output] no video data at {out_path}")


if __name__ == "__main__":
    main()
+1
Submodule third_party/MuseTalk added at ca5b7a8f28