Files
live-voice-chat/experimental/lightx2v_5b/test_i2v.py
T
bhetherman 9debc56137 Add LightX2V + Wan2.2-TI2V-5B-Turbo GGUF experiment
Benchmarks the dense 5B Turbo model (Q8_0 GGUF + fp8 T5) as a
lower-VRAM alternative to the 14B MoE pipeline. Includes dtype
patches for dense WanModel, Wan 2.2 VAE config (48 channels, 16x
spatial), and Blackwell fp8 workaround.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-16 01:27:45 -04:00

218 lines
7.1 KiB
Python

"""Benchmark Wan2.2-TI2V-5B-Turbo i2v under LightX2V.
Uses the dense `wan2.2` model_cls with a single Q8 GGUF DIT checkpoint.
Applies the same GGUF dtype patches as server/video_models/wan22.py (T5→bf16
wrapper, VAE→fp16, fp32 DIT pre/post weights→fp16).
Measures peak VRAM and wall time for an 81-frame 480p clip from the sample avatar.
Run:
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \\
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py
"""
from __future__ import annotations
import argparse
import json
import os
import tempfile
import time
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download, snapshot_download
# Hugging Face repo that main() snapshots as the base model directory
# (its *.pt and diffusion_pytorch_model*.safetensors files are skipped there).
BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B"
# Repo and file name of the GGUF DIT checkpoint; the file name can be
# overridden through the GGUF_FILE environment variable.
GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
GGUF_FILE = os.environ.get("GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf")
# Repo and file name of the fp8 T5 encoder checkpoint.
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
# Template config that sits next to this script; main() loads it and fills in
# the checkpoint paths before writing a runtime copy.
CONFIG_JSON = Path(__file__).parent / "config.json"
# Input image handed to the pipeline as "image_path".
SAMPLE_AVATAR = "/app/tests/component/sample_avatar.png"
# Final location of the generated clip.
OUTPUT = Path("/app/experimental/lightx2v_5b/_out/turbo_5b.mp4")
# Text prompt and RNG seed used for the single benchmark generation.
PROMPT = "a humanoid robot looking at the camera, shaking their head left and right, soft focus background"
SEED = 42
def human_gb(n: int) -> str:
    """Format a byte count as a two-decimal figure in units of 1024**3 bytes.

    The result is suffixed with " GB", e.g. ``human_gb(1024 ** 3)`` gives
    ``"1.00 GB"``.
    """
    bytes_per_unit = 1024 ** 3
    amount = n / bytes_per_unit
    return format(amount, ".2f") + " GB"
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
"""Recursively find and cast fp32 tensors to fp16 on any object tree."""
if visited is None:
visited = set()
obj_id = id(obj)
if obj_id in visited or depth > 6:
return 0
visited.add(obj_id)
n = 0
for attr_name in dir(obj):
if attr_name.startswith("__"):
continue
try:
val = getattr(obj, attr_name)
except Exception:
continue
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
try:
setattr(obj, attr_name, val.to(torch.float16))
n += 1
except Exception:
pass
elif hasattr(val, "__dict__") and not callable(val):
n += _cast_all_fp32_tensors(val, visited, depth + 1)
return n
def build_args(model_path: str, config_json: str) -> argparse.Namespace:
    """Assemble the argument namespace that main() passes to ``set_config``.

    Only *model_path* and *config_json* vary. The remaining fields are fixed:
    the dense ``wan2.2`` model class, the ``i2v`` task, an 81-frame target
    length and the module-level seed. All other inputs are empty strings,
    empty lists or ``None``.

    Args:
        model_path: Directory of the staged base model.
        config_json: Path to the runtime JSON config.

    Returns:
        A fresh ``argparse.Namespace`` holding the fields listed below.
    """
    fields = {
        # Model selection and run identity.
        "seed": SEED,
        "model_cls": "wan2.2",
        "task": "i2v",
        "support_tasks": [],
        "model_path": model_path,
        "sf_model_path": None,
        "config_json": config_json,
        # Prompting.
        "use_prompt_enhancer": False,
        "prompt": "",
        "negative_prompt": "",
        # Conditioning inputs.
        "image_path": "",
        "last_frame_path": "",
        "audio_path": "",
        "image_strength": "1.0",
        "image_frame_idx": "",
        "src_ref_images": None,
        "src_video": None,
        "src_mask": None,
        "src_pose_path": None,
        "src_face_path": None,
        "src_bg_path": None,
        "src_mask_path": None,
        "pose": None,
        "action_path": None,
        "action_ckpt": None,
        # Output handling and shape.
        "save_result_path": None,
        "return_result_tensor": False,
        "target_shape": [],
        "target_video_length": 81,
        "aspect_ratio": "",
        "video_path": None,
        "sr_ratio": 2.0,
    }
    return argparse.Namespace(**fields)
def main() -> None:
    """Stage checkpoints, load the runner, patch dtypes, and time one i2v run.

    Steps, in order:

    1. Download the base repo snapshot, the GGUF DIT file and the fp8 T5 file.
    2. Load ``CONFIG_JSON``, point it at the downloaded checkpoints and write
       it to a temporary runtime config.
    3. Apply the Blackwell fp8 patch, then build the config and runner while
       timing the load and reporting allocated VRAM.
    4. Switch ``DTYPE`` to FP16 and apply the T5/VAE/DIT dtype patches,
       followed by a generic sweep for leftover fp32 tensors.
    5. Run the pipeline once, report wall time and peak allocated VRAM, and
       copy the resulting clip to ``OUTPUT``.

    The temporary ``.mp4`` the pipeline writes to is always removed, including
    when generation raises. A missing or zero-byte result is reported instead
    of being published to ``OUTPUT``.
    """
    import contextlib
    import shutil
    import sys

    OUTPUT.parent.mkdir(parents=True, exist_ok=True)

    print("\n=== Stage model ===", flush=True)
    base_dir = snapshot_download(
        repo_id=BASE_REPO,
        ignore_patterns=["*.pt", "diffusion_pytorch_model*.safetensors"],
    )
    gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE)
    t5_fp8_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
    print(f" base_dir: {base_dir}")
    print(f" dit_gguf: {gguf_path}")
    print(f" t5_fp8: {t5_fp8_path}")

    with open(CONFIG_JSON, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    cfg.pop("_comment", None)
    cfg["dit_quantized_ckpt"] = gguf_path
    cfg["t5_quantized_ckpt"] = t5_fp8_path

    # delete=False: the file has to outlive this handle because its path is
    # handed on through build_args(). The `with` guarantees the handle is
    # closed (and the JSON flushed) even if serialisation fails.
    with tempfile.NamedTemporaryFile(
        prefix="wan22_5b_", suffix=".json",
        mode="w", delete=False, encoding="utf-8",
    ) as tmp:
        json.dump(cfg, tmp, indent=2)
    print(f" runtime config: {tmp.name}")

    # Import LightX2V after env is set.
    if "/app" not in sys.path:
        sys.path.insert(0, "/app")
    from server.video_models.wan22 import _patch_fp8_scaled_mm_for_blackwell
    from lightx2v.infer import init_runner
    from lightx2v.utils.input_info import (
        init_empty_input_info,
        update_input_info_from_dict,
    )
    from lightx2v.utils.set_config import set_config

    _patch_fp8_scaled_mm_for_blackwell()
    args = build_args(base_dir, tmp.name)

    print("\n=== set_config + init_runner ===", flush=True)
    # Reset once here, so the peak reported after generation covers both the
    # load and the generate phases.
    torch.cuda.reset_peak_memory_stats()
    t_load = time.perf_counter()
    config = set_config(args)
    runner = init_runner(config)
    load_time = time.perf_counter() - t_load
    vram_load = torch.cuda.memory_allocated()
    print(f"[load] time={load_time:.1f}s vram={human_gb(vram_load)}", flush=True)

    # GGUF dtype patches: flip DTYPE to FP16, patch T5/VAE/DIT to match.
    os.environ["DTYPE"] = "FP16"
    # NOTE(review): GET_SENSITIVE_DTYPE is imported but never used here;
    # confirm whether its cache should be cleared as well.
    from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE
    GET_DTYPE.cache_clear()

    # Reuse the patch methods from Wan22Pipeline as standalone functions
    # by constructing a minimal object with ._runner.
    from server.video_models.wan22 import Wan22Pipeline
    shim = Wan22Pipeline.__new__(Wan22Pipeline)
    shim._runner = runner
    Wan22Pipeline._patch_t5_dtype_for_gguf(shim)
    Wan22Pipeline._patch_vae_dtype_for_gguf(shim)
    Wan22Pipeline._patch_dit_fp32_weights_for_gguf(shim)

    # Dense model: patch_dit_fp32 above expects MoE wrapper (runner.model.model).
    # For dense wan2.2, the model is directly at runner.model — patch it explicitly.
    n = Wan22Pipeline._cast_fp32_dit_weights_in_model(runner.model)
    print(f"[patch] cast {n} fp32 DIT weights to fp16 (dense model)")

    # Sweep all nested objects for any remaining fp32 tensors (conv3d bias, etc).
    n_extra = _cast_all_fp32_tensors(runner.model)
    print(f"[patch] cast {n_extra} additional fp32 tensors to fp16")

    input_info = init_empty_input_info(args.task, args.support_tasks)

    print("\n=== generate ===", flush=True)
    # Reserve a unique path for the pipeline to write to. This already creates
    # an empty file on disk, so existence alone says nothing about success.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
        out_path = tf.name
    try:
        update_input_info_from_dict(
            input_info,
            {
                "seed": SEED,
                "prompt": PROMPT,
                "negative_prompt": "",
                "image_path": SAMPLE_AVATAR,
                "save_result_path": out_path,
                "target_video_length": 81,
                "return_result_tensor": False,
            },
        )
        t_gen = time.perf_counter()
        runner.run_pipeline(input_info)
        gen_time = time.perf_counter() - t_gen
        peak_vram = torch.cuda.max_memory_allocated()
        print(f"\n[generate] wall_time={gen_time:.1f}s peak_vram={human_gb(peak_vram)}", flush=True)

        # A zero-byte file is just the untouched placeholder, not a video.
        size = os.path.getsize(out_path) if os.path.exists(out_path) else 0
        if size > 0:
            # Stream the copy instead of loading the whole clip into memory;
            # copyfile creates the destination with default permissions.
            shutil.copyfile(out_path, OUTPUT)
            print(f"[output] wrote {OUTPUT} ({size} bytes)")
        else:
            print(f"[output] no video written to {out_path}")
    finally:
        # Always drop the temporary clip, including when generation raised.
        with contextlib.suppress(FileNotFoundError):
            os.remove(out_path)
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()