"""Benchmark Wan2.2-TI2V-5B-Turbo i2v under LightX2V. Uses the dense `wan2.2` model_cls with a single Q8 GGUF DIT checkpoint. Applies the same GGUF dtype patches as server/video_models/wan22.py (T5→bf16 wrapper, VAE→fp16, fp32 DIT pre/post weights→fp16). Measures peak VRAM and wall time for an 81-frame 480p clip from the sample avatar. Run: docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \\ run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py """ from __future__ import annotations import argparse import json import os import tempfile import time from pathlib import Path import torch from huggingface_hub import hf_hub_download, snapshot_download BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B" GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF" GGUF_FILE = os.environ.get("GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf") T5_FP8_REPO = "lightx2v/Encoders" T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors" CONFIG_JSON = Path(__file__).parent / "config.json" SAMPLE_AVATAR = "/app/tests/component/sample_avatar.png" OUTPUT = Path("/app/experimental/lightx2v_5b/_out/turbo_5b.mp4") PROMPT = "a humanoid robot looking at the camera, shaking their head left and right, soft focus background" SEED = 42 def human_gb(n: int) -> str: return f"{n / (1024 ** 3):.2f} GB" def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int: """Recursively find and cast fp32 tensors to fp16 on any object tree.""" if visited is None: visited = set() obj_id = id(obj) if obj_id in visited or depth > 6: return 0 visited.add(obj_id) n = 0 for attr_name in dir(obj): if attr_name.startswith("__"): continue try: val = getattr(obj, attr_name) except Exception: continue if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0: try: setattr(obj, attr_name, val.to(torch.float16)) n += 1 except Exception: pass elif hasattr(val, "__dict__") and not callable(val): n += _cast_all_fp32_tensors(val, visited, depth + 1) return n def build_args(model_path: str, config_json: str) -> argparse.Namespace: return argparse.Namespace( seed=SEED, model_cls="wan2.2", task="i2v", support_tasks=[], model_path=model_path, sf_model_path=None, config_json=config_json, use_prompt_enhancer=False, prompt="", negative_prompt="", image_path="", last_frame_path="", audio_path="", image_strength="1.0", image_frame_idx="", src_ref_images=None, src_video=None, src_mask=None, src_pose_path=None, src_face_path=None, src_bg_path=None, src_mask_path=None, pose=None, action_path=None, action_ckpt=None, save_result_path=None, return_result_tensor=False, target_shape=[], target_video_length=81, aspect_ratio="", video_path=None, sr_ratio=2.0, ) def main() -> None: OUTPUT.parent.mkdir(parents=True, exist_ok=True) print(f"\n=== Stage model ===", flush=True) base_dir = snapshot_download( repo_id=BASE_REPO, ignore_patterns=["*.pt", "diffusion_pytorch_model*.safetensors"], ) gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE) t5_fp8_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE) print(f" base_dir: {base_dir}") print(f" dit_gguf: {gguf_path}") print(f" t5_fp8: {t5_fp8_path}") with open(CONFIG_JSON, "r", encoding="utf-8") as f: cfg = json.load(f) cfg.pop("_comment", None) cfg["dit_quantized_ckpt"] = gguf_path cfg["t5_quantized_ckpt"] = t5_fp8_path tmp = tempfile.NamedTemporaryFile( prefix="wan22_5b_", suffix=".json", mode="w", delete=False, encoding="utf-8", ) json.dump(cfg, tmp, indent=2) tmp.close() print(f" runtime config: {tmp.name}") # Import LightX2V after env is set. import sys if "/app" not in sys.path: sys.path.insert(0, "/app") from server.video_models.wan22 import _patch_fp8_scaled_mm_for_blackwell from lightx2v.infer import init_runner from lightx2v.utils.input_info import ( init_empty_input_info, update_input_info_from_dict, ) from lightx2v.utils.set_config import set_config _patch_fp8_scaled_mm_for_blackwell() args = build_args(base_dir, tmp.name) print(f"\n=== set_config + init_runner ===", flush=True) torch.cuda.reset_peak_memory_stats() t_load = time.perf_counter() config = set_config(args) runner = init_runner(config) load_time = time.perf_counter() - t_load vram_load = torch.cuda.memory_allocated() print(f"[load] time={load_time:.1f}s vram={human_gb(vram_load)}", flush=True) # GGUF dtype patches: flip DTYPE to FP16, patch T5/VAE/DIT to match. os.environ["DTYPE"] = "FP16" from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE GET_DTYPE.cache_clear() # Reuse the patch methods from Wan22Pipeline as standalone functions # by constructing a minimal object with ._runner. from server.video_models.wan22 import Wan22Pipeline shim = Wan22Pipeline.__new__(Wan22Pipeline) shim._runner = runner Wan22Pipeline._patch_t5_dtype_for_gguf(shim) Wan22Pipeline._patch_vae_dtype_for_gguf(shim) Wan22Pipeline._patch_dit_fp32_weights_for_gguf(shim) # Dense model: patch_dit_fp32 above expects MoE wrapper (runner.model.model). # For dense wan2.2, the model is directly at runner.model — patch it explicitly. n = Wan22Pipeline._cast_fp32_dit_weights_in_model(runner.model) print(f"[patch] cast {n} fp32 DIT weights to fp16 (dense model)") # Sweep all nested objects for any remaining fp32 tensors (conv3d bias, etc). n_extra = _cast_all_fp32_tensors(runner.model) print(f"[patch] cast {n_extra} additional fp32 tensors to fp16") input_info = init_empty_input_info(args.task, args.support_tasks) print(f"\n=== generate ===", flush=True) with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf: out_path = tf.name update_input_info_from_dict( input_info, { "seed": SEED, "prompt": PROMPT, "negative_prompt": "", "image_path": SAMPLE_AVATAR, "save_result_path": out_path, "target_video_length": 81, "return_result_tensor": False, }, ) t_gen = time.perf_counter() runner.run_pipeline(input_info) gen_time = time.perf_counter() - t_gen peak_vram = torch.cuda.max_memory_allocated() print(f"\n[generate] wall_time={gen_time:.1f}s peak_vram={human_gb(peak_vram)}", flush=True) if os.path.exists(out_path): size = os.path.getsize(out_path) OUTPUT.write_bytes(Path(out_path).read_bytes()) os.remove(out_path) print(f"[output] wrote {OUTPUT} ({size} bytes)") else: print(f"[output] no file at {out_path}") if __name__ == "__main__": main()