9debc56137
Benchmarks the dense 5B Turbo model (Q8_0 GGUF + fp8 T5) as a lower-VRAM alternative to the 14B MoE pipeline. Includes dtype patches for dense WanModel, Wan 2.2 VAE config (48 channels, 16x spatial), and Blackwell fp8 workaround. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
218 lines
7.1 KiB
Python
218 lines
7.1 KiB
Python
"""Benchmark Wan2.2-TI2V-5B-Turbo i2v under LightX2V.
|
|
|
|
Uses the dense `wan2.2` model_cls with a single Q8 GGUF DIT checkpoint.
|
|
Applies the same GGUF dtype patches as server/video_models/wan22.py (T5→bf16
|
|
wrapper, VAE→fp16, fp32 DIT pre/post weights→fp16).
|
|
|
|
Measures peak VRAM and wall time for an 81-frame 480p clip from the sample avatar.
|
|
|
|
Run:
|
|
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \\
|
|
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py
|
|
"""
|
|
from __future__ import annotations

import argparse
import json
import os
import tempfile
import time
import types
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download, snapshot_download
|
|
|
|
# --- Hugging Face Hub sources ---------------------------------------------
# Base Wan2.2 TI2V-5B repo; the DIT itself is taken from the GGUF file below.
BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B"
# Quantized Turbo DIT; the filename can be overridden via $GGUF_FILE.
GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
GGUF_FILE = os.getenv("GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf")
# fp8 T5 text-encoder checkpoint.
T5_FP8_REPO, T5_FP8_FILE = "lightx2v/Encoders", "models_t5_umt5-xxl-enc-fp8.safetensors"

# --- Local paths -------------------------------------------------------------
CONFIG_JSON = Path(__file__).parent.joinpath("config.json")
SAMPLE_AVATAR = "/app/tests/component/sample_avatar.png"
OUTPUT = Path("/app/experimental/lightx2v_5b") / "_out" / "turbo_5b.mp4"

# --- Generation inputs ---------------------------------------------------------
PROMPT = (
    "a humanoid robot looking at the camera, "
    "shaking their head left and right, "
    "soft focus background"
)
SEED = 42
|
|
|
|
|
|
def human_gb(n: int) -> str:
    """Format a byte count *n* as a two-decimal "GB" string.

    Units are binary: one "GB" here is 2**30 bytes, e.g. ``1.50 GB`` for
    1.5 * 2**30 bytes.
    """
    gib = n / 2 ** 30
    return "{:.2f} GB".format(gib)
|
|
|
|
|
|
def _cast_all_fp32_tensors(obj, visited=None, depth: int = 0, max_depth: int = 6) -> int:
    """Recursively replace fp32 tensor attributes reachable from *obj* with fp16 copies.

    For every non-dunder name in ``dir(obj)``:

    * a non-empty ``torch.float32`` tensor is replaced (via ``setattr``) with a
      ``float16`` copy;
    * any other tensor is left alone and NOT descended into. Tensors are
      leaves here: attribute access on a tensor (``.T``, ``.data``, ...)
      returns fresh view objects, making a walk into them slow. For a
      complex64 tensor, ``.real``/``.imag`` are float32 views, so assigning
      fp16 copies back through them would round the complex values in place;
    * any other non-callable object with a ``__dict__`` is walked recursively.

    Callables are not descended into. That covers functions, bound methods,
    classes and any object defining ``__call__``, including
    ``torch.nn.Module`` instances. Python module objects are also skipped so
    the sweep never mutates process-global state.

    Args:
        obj: Root of the object graph to sweep.
        visited: ids of objects already walked, shared across the recursion to
            break reference cycles. Callers normally leave this as ``None``.
        depth: Current recursion depth. Callers normally leave this as ``0``.
        max_depth: Objects deeper than this are not walked. The default of 6
            matches the previously hard-coded limit.

    Returns:
        Number of tensors actually replaced. Attributes whose ``getattr`` or
        ``setattr`` raises are skipped (best effort) and not counted; an
        example is a read-only property.
    """
    if visited is None:
        visited = set()
    if id(obj) in visited or depth > max_depth:
        return 0
    visited.add(id(obj))

    cast = 0
    for name in dir(obj):
        if name.startswith("__"):
            continue
        try:
            value = getattr(obj, name)
        except Exception:
            # Properties/descriptors may raise on access; nothing to cast.
            continue

        if isinstance(value, torch.Tensor):
            if value.dtype == torch.float32 and value.numel() > 0:
                try:
                    setattr(obj, name, value.to(torch.float16))
                except Exception:
                    # Best effort: read-only or type-checked slots stay fp32.
                    continue
                cast += 1
            # Tensors are leaves; never recurse into them (see docstring).
            continue

        if (
            hasattr(value, "__dict__")
            and not callable(value)
            and not isinstance(value, types.ModuleType)
        ):
            cast += _cast_all_fp32_tensors(value, visited, depth + 1, max_depth)
    return cast
|
|
|
|
|
|
def build_args(model_path: str, config_json: str) -> argparse.Namespace:
    """Build the argparse-style namespace handed to LightX2V's ``set_config``.

    Only *model_path* and *config_json* vary between calls; every other field
    is a fixed value for a single-clip i2v run. The prompt, image and output
    fields are left blank here. Per-run values are supplied later through
    ``input_info`` rather than through this namespace.
    """
    # Field insertion order matches the original keyword order.
    fields = {
        "seed": SEED,
        "model_cls": "wan2.2",
        "task": "i2v",
        "support_tasks": [],
        "model_path": model_path,
        "sf_model_path": None,
        "config_json": config_json,
        "use_prompt_enhancer": False,
    }
    # Prompt / conditioning inputs (blank placeholders).
    fields.update(
        prompt="",
        negative_prompt="",
        image_path="",
        last_frame_path="",
        audio_path="",
        image_strength="1.0",
        image_frame_idx="",
    )
    # Source / control inputs this benchmark does not use.
    unused = (
        "src_ref_images", "src_video", "src_mask", "src_pose_path",
        "src_face_path", "src_bg_path", "src_mask_path", "pose",
        "action_path", "action_ckpt", "save_result_path",
    )
    fields.update(dict.fromkeys(unused, None))
    # Output and shape settings.
    fields.update(
        return_result_tensor=False,
        target_shape=[],
        target_video_length=81,
        aspect_ratio="",
        video_path=None,
        sr_ratio=2.0,
    )
    return argparse.Namespace(**fields)
|
|
|
|
|
|
def _stage_model() -> tuple[str, str, str]:
    """Fetch the base repo, the GGUF DIT and the fp8 T5 from the Hugging Face Hub.

    Returns:
        Local paths ``(base_dir, gguf_path, t5_fp8_path)``.
    """
    print("\n=== Stage model ===", flush=True)
    base_dir = snapshot_download(
        repo_id=BASE_REPO,
        # The DIT is loaded from the GGUF checkpoint (dit_quantized_ckpt), so
        # the base repo's diffusion_pytorch_model shards are not fetched.
        ignore_patterns=["*.pt", "diffusion_pytorch_model*.safetensors"],
    )
    gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE)
    t5_fp8_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
    print(f" base_dir: {base_dir}")
    print(f" dit_gguf: {gguf_path}")
    print(f" t5_fp8: {t5_fp8_path}")
    return base_dir, gguf_path, t5_fp8_path


def _write_runtime_config(gguf_path: str, t5_fp8_path: str) -> str:
    """Write CONFIG_JSON, pointed at the downloaded checkpoints, to a temp file.

    The ``_comment`` key is dropped. The caller owns the returned path and is
    responsible for deleting it.
    """
    with open(CONFIG_JSON, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    cfg.pop("_comment", None)
    cfg["dit_quantized_ckpt"] = gguf_path
    cfg["t5_quantized_ckpt"] = t5_fp8_path

    with tempfile.NamedTemporaryFile(
        prefix="wan22_5b_", suffix=".json",
        mode="w", delete=False, encoding="utf-8",
    ) as tmp:
        json.dump(cfg, tmp, indent=2)
    print(f" runtime config: {tmp.name}")
    return tmp.name


def _apply_gguf_dtype_patches(runner) -> None:
    """Switch LightX2V to FP16 and cast T5/VAE/DIT tensors to match the GGUF DIT.

    Wan22Pipeline's patch methods are reused on a bare instance that carries
    only ``_runner``; this assumes those methods need no other instance state.
    The dense model's fp32 weights are then cast explicitly.
    """
    os.environ["DTYPE"] = "FP16"
    from lightx2v.utils.envs import GET_DTYPE

    # GET_DTYPE is cached (it exposes cache_clear). Drop the cached value so
    # later calls can pick up DTYPE=FP16, assuming it reads that env var.
    GET_DTYPE.cache_clear()
    # NOTE(review): only GET_DTYPE's cache is cleared. If
    # lightx2v.utils.envs.GET_SENSITIVE_DTYPE is also cached and derived from
    # DTYPE, it may still return the pre-flip dtype. Confirm.

    from server.video_models.wan22 import Wan22Pipeline

    shim = Wan22Pipeline.__new__(Wan22Pipeline)  # bypasses __init__
    shim._runner = runner
    Wan22Pipeline._patch_t5_dtype_for_gguf(shim)
    Wan22Pipeline._patch_vae_dtype_for_gguf(shim)
    Wan22Pipeline._patch_dit_fp32_weights_for_gguf(shim)

    # _patch_dit_fp32_weights_for_gguf targets the MoE wrapper
    # (runner.model.model). The dense wan2.2 model sits directly at
    # runner.model, so cast it explicitly.
    n = Wan22Pipeline._cast_fp32_dit_weights_in_model(runner.model)
    print(f"[patch] cast {n} fp32 DIT weights to fp16 (dense model)")
    # Sweep nested objects for any remaining fp32 tensors (conv3d bias, etc).
    n_extra = _cast_all_fp32_tensors(runner.model)
    print(f"[patch] cast {n_extra} additional fp32 tensors to fp16")


def _save_output(out_path: str) -> None:
    """Copy a non-empty result at *out_path* to OUTPUT, or report that none was written."""
    import shutil

    try:
        size = os.path.getsize(out_path)
    except OSError:
        size = 0
    # out_path was pre-created (empty) by NamedTemporaryFile, so mere
    # existence does not prove the pipeline wrote a video.
    if size == 0:
        print(f"[output] no video written to {out_path}")
        return
    # copyfile (not move) so OUTPUT gets default permissions rather than the
    # temp file's restrictive ones, and the clip is streamed, not held in RAM.
    shutil.copyfile(out_path, OUTPUT)
    print(f"[output] wrote {OUTPUT} ({size} bytes)")


def _run_benchmark(base_dir: str, config_path: str) -> None:
    """Load the runner, patch dtypes, generate one clip, and print timings/VRAM."""
    import sys
    from contextlib import suppress

    # server.* lives under /app; make it importable before the deferred imports.
    if "/app" not in sys.path:
        sys.path.insert(0, "/app")
    from server.video_models.wan22 import _patch_fp8_scaled_mm_for_blackwell
    from lightx2v.infer import init_runner
    from lightx2v.utils.input_info import (
        init_empty_input_info,
        update_input_info_from_dict,
    )
    from lightx2v.utils.set_config import set_config

    _patch_fp8_scaled_mm_for_blackwell()

    args = build_args(base_dir, config_path)
    print("\n=== set_config + init_runner ===", flush=True)
    torch.cuda.reset_peak_memory_stats()
    t_load = time.perf_counter()
    config = set_config(args)
    runner = init_runner(config)
    # CUDA work is asynchronous; wait for it so the timer covers all of it.
    torch.cuda.synchronize()
    load_time = time.perf_counter() - t_load
    vram_load = torch.cuda.memory_allocated()
    print(f"[load] time={load_time:.1f}s vram={human_gb(vram_load)}", flush=True)

    _apply_gguf_dtype_patches(runner)

    input_info = init_empty_input_info(args.task, args.support_tasks)

    print("\n=== generate ===", flush=True)
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
        out_path = tf.name
    try:
        update_input_info_from_dict(
            input_info,
            {
                "seed": args.seed,
                "prompt": PROMPT,
                "negative_prompt": "",
                "image_path": SAMPLE_AVATAR,
                "save_result_path": out_path,
                "target_video_length": args.target_video_length,
                "return_result_tensor": False,
            },
        )

        # Capture the load+patch peak, then reset so generation gets its own.
        setup_peak = torch.cuda.max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()
        t_gen = time.perf_counter()
        runner.run_pipeline(input_info)
        torch.cuda.synchronize()
        gen_time = time.perf_counter() - t_gen
        gen_peak = torch.cuda.max_memory_allocated()
        peak_vram = max(setup_peak, gen_peak)  # whole-run peak
        print(
            f"\n[generate] wall_time={gen_time:.1f}s peak_vram={human_gb(peak_vram)} "
            f"gen_peak_vram={human_gb(gen_peak)}",
            flush=True,
        )
        _save_output(out_path)
    finally:
        # Always remove the temp clip, including when run_pipeline raises.
        with suppress(FileNotFoundError):
            os.remove(out_path)


def main() -> None:
    """Run the end-to-end 5B Turbo i2v benchmark and write the clip to OUTPUT."""
    from contextlib import suppress

    # Fail before any downloads rather than at the first torch.cuda call.
    if not torch.cuda.is_available():
        raise SystemExit("[error] CUDA is not available; this benchmark needs a GPU")

    OUTPUT.parent.mkdir(parents=True, exist_ok=True)
    base_dir, gguf_path, t5_fp8_path = _stage_model()
    config_path = _write_runtime_config(gguf_path, t5_fp8_path)
    try:
        _run_benchmark(base_dir, config_path)
    finally:
        # The runtime config is per-run scratch; don't leave it in /tmp.
        with suppress(FileNotFoundError):
            os.remove(config_path)


if __name__ == "__main__":
    main()
|