"""Wan2.2-Lightning image-to-video wrapper via LightX2V.

This wrapper targets LightX2V's actual Python entry points (verified against
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main)::

    from lightx2v.utils.set_config import set_config
    from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
    from lightx2v.infer import init_runner

    args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...)
    config = set_config(args)
    input_info = init_empty_input_info(args.task, args.support_tasks)
    runner = init_runner(config)      # loads all weights — done ONCE

    update_input_info_from_dict(input_info, {"seed": ..., "prompt": ...,
                                             "image_path": ..., "save_result_path": ...})
    runner.run_pipeline(input_info)   # per-turn; MP4 written to save_result_path

    # LoRA hot-swap:
    runner.switch_lora(lora_path, strength)   # swap in
    runner.switch_lora("", 0.0)               # remove

Model weights are loaded once at construction and held resident across turns
so reflective mode doesn't re-pay the load cost each reply.

HuggingFace repos consumed on first run (cached under HF_HOME):

- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only. The bf16
  DIT shards under high_noise_model/ and low_noise_model/ are SKIPPED via
  ignore_patterns — we replace them with quantised checkpoints from dit_repo.
- dit_repo (configurable) — quantised DIT checkpoints. Supported formats:

  * fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
  * GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)

- lightx2v/Encoders — only when ``t5_quantized=True``: the fp8 T5 encoder.
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import random
import re
import tempfile
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from server.video import LoRASpec

log = logging.getLogger(__name__)

# --- fp8 distill filenames --------------------------------------------------
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"

# --- GGUF filenames ---------------------------------------------------------
# QuantStack layout: HighNoise/Wan2.2-I2V-A14B-HighNoise-<quant>.gguf
# (and the same under LowNoise/).
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"

# dit_quant_scheme values starting with this prefix select GGUF, e.g. "gguf-Q4_K_M".
_GGUF_SCHEME_PREFIX = "gguf-"

# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"

# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the
# quantised files from dit_repo replace the DIT weights entirely. We must
# keep the config.json / index.json metadata under high_noise_model/ and
# low_noise_model/ (LightX2V's set_config reads architecture params like
# ``dim`` from them) and the tokenizer files under google/.
BASE_REPO_IGNORE_PATTERNS = [
    "high_noise_model/*.safetensors",
    "low_noise_model/*.safetensors",
    "assets/*",
    "examples/*",
    "nohup.out",
    "*.md",
]

# LoRASpec.target -> expert name used in LightX2V's ``lora_configs`` entries.
_LORA_TARGET_TO_MODEL_NAME = {
    "high_noise": "high_noise_model",
    "low_noise": "low_noise_model",
}

# Windows drive-letter paths ("C:\\..." / "C:/...") contain a colon but are
# local paths, not ``repo_id:filename`` HuggingFace references.
_WINDOWS_DRIVE_RE = re.compile(r"^[A-Za-z]:[\\/]")


def _set_lightx2v_dtype(flag: str) -> None:
    """Set LightX2V's global ``DTYPE`` env var and invalidate its cached getters.

    ``GET_DTYPE`` and ``GET_SENSITIVE_DTYPE`` are cached (they expose
    ``cache_clear``), so changing the env var has no effect until both caches
    are cleared. Clearing both keeps every caller of this helper consistent.

    Args:
        flag: ``"BF16"`` or ``"FP16"``.
    """
    from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE  # type: ignore[import-not-found]

    os.environ["DTYPE"] = flag
    GET_DTYPE.cache_clear()
    GET_SENSITIVE_DTYPE.cache_clear()


def _patch_fp8_scaled_mm_for_blackwell() -> None:
    """Replace ``sgl_kernel.fp8_scaled_mm`` with ``torch._scaled_mm`` on SM120+.

    sgl_kernel's CUTLASS-based fp8 GEMM doesn't ship SM120 kernels yet.
    PyTorch 2.8+'s native ``_scaled_mm`` works on all architectures including
    Blackwell.

    The patch is a no-op when sgl_kernel is not installed, CUDA is unavailable,
    or the device's compute-capability major version is below 12. It is
    idempotent: a marker attribute on the ``sgl_kernel`` module prevents
    double-patching.
    """
    try:
        import sgl_kernel  # type: ignore[import-not-found]
    except ImportError:
        return  # no sgl_kernel → fp8 T5 not in use

    if getattr(sgl_kernel, "_fp8_patched_for_blackwell", False):
        return

    import torch

    if not torch.cuda.is_available():
        return
    cap = torch.cuda.get_device_capability()
    if cap[0] < 12:
        return  # only patch on SM120+

    def _torch_fp8_scaled_mm(
        a: torch.Tensor,
        b: torch.Tensor,
        a_scale: torch.Tensor,
        b_scale: torch.Tensor,
        out_dtype: torch.dtype,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # torch._scaled_mm expects (M,K) @ (N,K).t() with:
        #   scale_a: scalar or (M,1)
        #   scale_b: scalar or (1,N)
        # sgl_kernel provides scale_b as (N,1) — transpose it.
        if b_scale.dim() == 2 and b_scale.shape[1] == 1:
            b_scale = b_scale.t()
        # _scaled_mm requires B to be column-major (stride(0)==1).
        bt = b.t().contiguous().t()
        return torch._scaled_mm(
            a, bt,
            scale_a=a_scale,
            scale_b=b_scale,
            out_dtype=out_dtype,
            bias=bias,
        )

    sgl_kernel.fp8_scaled_mm = _torch_fp8_scaled_mm
    sgl_kernel._fp8_patched_for_blackwell = True
    log.info("Patched sgl_kernel.fp8_scaled_mm → torch._scaled_mm for Blackwell (SM%d%d).", *cap)


class Wan22Pipeline:
    """Wrapper around LightX2V's Wan2.2 MoE distill runner.

    Supports two DIT quantisation formats:

    * **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
      ``lightx2v/Wan2.2-Distill-Models``)
    * **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
      from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)

    The constructor downloads (if needed) the HF repos, writes a runtime JSON
    config with absolute ckpt paths, then drives
    ``lightx2v.infer.init_runner``. ``generate_i2v`` runs one inference turn
    against the already-loaded runner.

    Args:
        base_repo: HF repo id or local directory holding the Wan2.2 support
            files (T5, VAE, tokenizer, per-expert config.json).
        dit_repo: HF repo id or local directory holding the quantised DIT pair.
        config_json: Path to the LightX2V JSON template. Its ``fps`` (if set)
            takes precedence over the ``fps`` argument.
        model_cls: LightX2V model class name.
        resolution: Stored on the instance; not read anywhere in this module.
        fps: Output frame rate used when the template does not set one.
        dit_quant_scheme: ``"fp8-sgl"`` or ``"gguf-<QUANT>"``.
        t5_quantized: Download and use the fp8 T5 encoder.
    """

    def __init__(
        self,
        base_repo: str,
        dit_repo: str,
        config_json: str,
        model_cls: str = "wan2.2_moe_distill",
        resolution: int = 480,
        fps: int = 16,
        dit_quant_scheme: str = "fp8-sgl",
        t5_quantized: bool = False,
    ):
        self.base_repo = base_repo
        self.dit_repo = dit_repo
        self.config_json_template = config_json
        self.model_cls = model_cls
        self.resolution = resolution
        self.fps = fps
        self.dit_quant_scheme = dit_quant_scheme
        self.t5_quantized = t5_quantized
        self._applied_loras: list[LoRASpec] = []
        self._is_gguf = dit_quant_scheme.startswith(_GGUF_SCHEME_PREFIX)

        # 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
        self._model_root = self._ensure_base_repo(base_repo)
        self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
            dit_repo, dit_quant_scheme,
        )
        self._t5_fp8_ckpt = self._ensure_t5_fp8() if t5_quantized else None

        # 2. Materialize a runtime JSON config with absolute ckpt paths.
        #    (May update self.fps to match the template's fps.)
        self._runtime_json_path = self._build_runtime_config()

        # 3. Build the argparse-like namespace LightX2V.set_config() expects.
        args = self._build_args(
            model_cls=model_cls,
            model_path=self._model_root,
            config_json=self._runtime_json_path,
        )

        # 4. Import LightX2V (scoped here so ``import server.video_models.wan22``
        #    never pulls in lightx2v — tests can import this module on CPU).
        from lightx2v.utils.set_config import set_config  # type: ignore[import-not-found]
        from lightx2v.utils.input_info import init_empty_input_info  # type: ignore[import-not-found]
        from lightx2v.infer import init_runner  # type: ignore[import-not-found]

        _patch_fp8_scaled_mm_for_blackwell()

        # 5. Load all models under default DTYPE=BF16 so T5 (which is
        #    hardcoded to bf16 weights) initialises its offload buffers
        #    correctly. We flip to FP16 *after* init_runner completes.
        log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
                 model_cls, self._model_root)
        self._config = set_config(args)
        self._input_info_template = init_empty_input_info(
            args.task, args.support_tasks
        )

        log.info("LightX2V init_runner — loading weights (this takes a while)...")
        self._runner = init_runner(self._config)
        log.info("LightX2V runner loaded; weights resident.")

        # 6. GGUF: switch global DTYPE to FP16 for inference. GGUF DIT
        #    dequantises to fp16, and many intermediate tensors inside the
        #    DIT forward pass are allocated via GET_DTYPE(). Both cached dtype
        #    getters are cleared (same as the T5 wrapper does) so none of them
        #    keeps returning a stale BF16. The T5 encoder is wrapped to
        #    temporarily restore BF16 during its forward.
        if self._is_gguf:
            _set_lightx2v_dtype("FP16")
            from lightx2v.utils.envs import GET_DTYPE  # type: ignore[import-not-found]
            log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
            self._patch_t5_dtype_for_gguf()

    # --- GGUF dtype compatibility patch ----------------------------------------

    def _patch_t5_dtype_for_gguf(self) -> None:
        """Wrap the T5 encoder so it temporarily restores DTYPE=BF16.

        The T5 encoder is hardcoded to bfloat16 weights (wan_runner.py). When
        the global DTYPE is FP16 (required for GGUF DIT), the T5's CPU-offload
        path breaks because intermediate tensor dtypes no longer match the
        bf16 weights. We wrap ``run_text_encoder`` to temporarily flip
        GET_DTYPE() back to bf16, then restore fp16 (even on error) before the
        DIT runs.
        """
        import types

        runner = self._runner
        orig_run_text_encoder = runner.run_text_encoder.__func__

        def bf16_text_encoder(self_runner, *args, **kwargs):
            # Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
            _set_lightx2v_dtype("BF16")
            try:
                return orig_run_text_encoder(self_runner, *args, **kwargs)
            finally:
                # Restore FP16 for the DIT / rest of the pipeline.
                _set_lightx2v_dtype("FP16")

        runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
        log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")

    # --- Weight provisioning -------------------------------------------------

    @staticmethod
    def _ensure_base_repo(base_repo: str) -> str:
        """Return a local directory containing the Wan2.2 base support files.

        If ``base_repo`` is already a local directory, use it as-is. Otherwise
        snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
        shards (they're replaced by the quantised files).
        """
        if os.path.isdir(base_repo):
            return base_repo
        from huggingface_hub import snapshot_download
        log.info("Downloading Wan2.2 base support files from %s "
                 "(skipping bf16 DIT shards)...", base_repo)
        return snapshot_download(
            repo_id=base_repo,
            ignore_patterns=BASE_REPO_IGNORE_PATTERNS,
        )

    @staticmethod
    def _ensure_dit_checkpoints(
        dit_repo: str, dit_quant_scheme: str,
    ) -> tuple[str, str]:
        """Return (high_noise_path, low_noise_path) for the DIT pair.

        Supports both fp8 safetensors and GGUF formats.

        Raises:
            ValueError: ``dit_repo`` is empty.
            FileNotFoundError: ``dit_repo`` is a local directory missing
                either checkpoint.
        """
        if not dit_repo:
            raise ValueError("dit_repo must be a HF repo id or local directory.")

        if dit_quant_scheme.startswith(_GGUF_SCHEME_PREFIX):
            # Strip only the leading prefix, e.g. "gguf-Q4_K_M" → "Q4_K_M".
            quant = dit_quant_scheme[len(_GGUF_SCHEME_PREFIX):]
            high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
            low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
        else:
            high_file = FP8_HIGH_NOISE_FILE
            low_file = FP8_LOW_NOISE_FILE

        # Local directory?
        if os.path.isdir(dit_repo):
            high = os.path.join(dit_repo, high_file)
            low = os.path.join(dit_repo, low_file)
            if not (os.path.isfile(high) and os.path.isfile(low)):
                raise FileNotFoundError(
                    f"DIT checkpoints not found in {dit_repo}: expected "
                    f"{high_file} and {low_file}"
                )
            return high, low

        # HuggingFace download.
        from huggingface_hub import hf_hub_download
        log.info("Downloading %s DIT checkpoints from %s ...",
                 dit_quant_scheme, dit_repo)
        high = hf_hub_download(repo_id=dit_repo, filename=high_file)
        low = hf_hub_download(repo_id=dit_repo, filename=low_file)
        return high, low

    @staticmethod
    def _ensure_t5_fp8() -> str:
        """Download the fp8 T5 encoder from lightx2v/Encoders (if not cached).

        Returns the local path to the safetensors file (~6 GB).
        """
        from huggingface_hub import hf_hub_download
        log.info("Downloading fp8 T5 encoder from %s ...", T5_FP8_REPO)
        return hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)

    def _build_runtime_config(self) -> str:
        """Load the template JSON, inject absolute ckpt paths, persist to temp.

        If the template sets an ``fps`` different from ``self.fps``, the
        template wins and ``self.fps`` is updated to match, so the frame count
        computed in ``generate_i2v`` agrees with the fps the video is saved at.

        Returns:
            Path of the written temp JSON file (not deleted automatically).
        """
        with open(self.config_json_template, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        # Drop editorial comments before passing to LightX2V.
        cfg.pop("_comment", None)
        cfg["high_noise_quantized_ckpt"] = self._dit_high
        cfg["low_noise_quantized_ckpt"] = self._dit_low

        cfg.setdefault("fps", self.fps)
        if cfg["fps"] != self.fps:
            log.warning(
                "Config template %s sets fps=%s (constructor fps=%s); using "
                "the template value for frame-count computation.",
                self.config_json_template, cfg["fps"], self.fps,
            )
            self.fps = int(cfg["fps"])

        # T5 fp8 quantization.
        if self._t5_fp8_ckpt:
            cfg["t5_quantized"] = True
            cfg["t5_quant_scheme"] = "fp8-sgl"
            cfg["t5_quantized_ckpt"] = self._t5_fp8_ckpt

        # delete=False: LightX2V reads this file by path after we return.
        with tempfile.NamedTemporaryFile(
            prefix="wan22_dit_", suffix=".json", mode="w",
            delete=False, encoding="utf-8",
        ) as tmp:
            json.dump(cfg, tmp, indent=2)
        log.info("Runtime LightX2V config: %s", tmp.name)
        return tmp.name

    @staticmethod
    def _build_args(
        *, model_cls: str, model_path: str, config_json: str
    ) -> argparse.Namespace:
        """Mirror every field from ``lightx2v.infer.main``'s argparse so
        ``set_config`` finds the attributes it expects. We only customize the
        model/task/path fields; everything else stays at the CLI defaults.
        """
        return argparse.Namespace(
            seed=42,
            model_cls=model_cls,
            task="i2v",
            support_tasks=[],
            model_path=model_path,
            sf_model_path=None,
            config_json=config_json,
            use_prompt_enhancer=False,
            prompt="",
            negative_prompt="",
            image_path="",
            last_frame_path="",
            audio_path="",
            image_strength="1.0",
            image_frame_idx="",
            src_ref_images=None,
            src_video=None,
            src_mask=None,
            src_pose_path=None,
            src_face_path=None,
            src_bg_path=None,
            src_mask_path=None,
            pose=None,
            action_path=None,
            action_ckpt=None,
            save_result_path=None,
            return_result_tensor=False,
            target_shape=[],
            target_video_length=81,
            aspect_ratio="",
            video_path=None,
            sr_ratio=2.0,
        )

    # --- LoRA --------------------------------------------------------------

    def load_loras(self, specs: list["LoRASpec"]) -> None:
        """Apply LoRAs to the Wan2.2 MoE distill pipeline.

        Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"`` to
        route the LoRA to the correct expert. Targets are validated before any
        path is resolved, so a bad spec never triggers a download.

        With ``lazy_load`` the DIT models are ``None`` at this point, so
        runtime ``switch_lora`` is impossible. Instead we inject
        ``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
        the LoRAs are applied when the models materialise on first inference.

        Without ``lazy_load`` (models already resident) we call
        ``switch_lora`` with explicit high/low keyword args. That path carries
        one LoRA per expert; if several specs share a target, the last one
        wins and a warning is logged.

        Raises:
            ValueError: a spec has an unknown ``target``.
            RuntimeError: ``runner.switch_lora`` reported failure.
        """
        if not specs:
            return

        for spec in specs:
            if spec.target not in _LORA_TARGET_TO_MODEL_NAME:
                raise ValueError(
                    f"LoRA target must be 'high_noise' or 'low_noise', "
                    f"got {spec.target!r}")

        # Resolve every path up-front (may trigger HF download).
        resolved: list[tuple["LoRASpec", str]] = []
        for spec in specs:
            local_path = self._resolve_lora_path(spec.path)
            log.info("  LoRA %s → strength=%.2f target=%s (%s)",
                     spec.name or spec.path, spec.weight, spec.target, local_path)
            resolved.append((spec, local_path))

        lazy = self._config.get("lazy_load", False)

        if lazy:
            # Build the lora_configs list that LightX2V's lazy-load path
            # reads inside MultiDistillModelStruct.infer().
            lora_cfgs = [
                {
                    "name": _LORA_TARGET_TO_MODEL_NAME[spec.target],
                    "path": local_path,
                    "strength": spec.weight,
                }
                for spec, local_path in resolved
            ]
            self._runner.set_config({
                "lora_configs": lora_cfgs,
                "lora_dynamic_apply": True,
            })
        else:
            # Models are loaded — use runtime hot-swap (one LoRA per expert).
            per_target: dict[str, tuple[str, float]] = {}
            for spec, local_path in resolved:
                if spec.target in per_target:
                    log.warning(
                        "Multiple LoRAs target %s; runtime hot-swap applies "
                        "only one per expert — using %s", spec.target, local_path)
                per_target[spec.target] = (local_path, spec.weight)

            kwargs: dict[str, object] = {}
            if "high_noise" in per_target:
                kwargs["high_lora_path"], kwargs["high_lora_strength"] = per_target["high_noise"]
            if "low_noise" in per_target:
                kwargs["low_lora_path"], kwargs["low_lora_strength"] = per_target["low_noise"]
            ok = self._runner.switch_lora(**kwargs)
            if not ok:
                raise RuntimeError(
                    "runner.switch_lora returned False. Check that your "
                    "LightX2V build supports runtime LoRA updates for "
                    f"{self.model_cls}.")

        self._applied_loras = list(specs)

    def unload_loras(self) -> None:
        """Remove all currently applied LoRAs (no-op if none are applied)."""
        if not self._applied_loras:
            return

        lazy = self._config.get("lazy_load", False)

        if lazy:
            self._runner.set_config({
                "lora_configs": None,
                "lora_dynamic_apply": False,
            })
            # If models were materialised, drop them so the next inference
            # recreates them without LoRAs.
            model_struct = getattr(self._runner, "model", None)
            if model_struct is not None and hasattr(model_struct, "model"):
                for i in range(len(model_struct.model)):
                    model_struct.model[i] = None
        else:
            ok = self._runner.switch_lora("", 0.0)
            if ok is False:
                log.warning("runner.switch_lora('', 0.0) returned False; "
                            "LoRAs may still be applied.")

        self._applied_loras = []

    @staticmethod
    def _resolve_lora_path(path: str) -> str:
        """Resolve a LoRA path.

        Supports:
          - Local paths: returned as-is. This includes existing files, absolute
            paths, paths starting with ``.`` or ``~``, and Windows drive paths
            such as ``C:\\loras\\x.safetensors``.
          - ``repo_id:filename`` HuggingFace references, downloaded via
            ``hf_hub_download``.

        Anything else is returned unchanged.
        """
        if os.path.isfile(path):
            return path
        looks_local = (
            os.path.isabs(path)
            or path.startswith((".", "~"))
            or _WINDOWS_DRIVE_RE.match(path) is not None
        )
        if ":" in path and not looks_local:
            repo_id, filename = path.split(":", 1)
            from huggingface_hub import hf_hub_download
            return hf_hub_download(repo_id=repo_id, filename=filename)
        return path

    # --- Inference ---------------------------------------------------------

    def generate_i2v(
        self,
        image_path: str,
        prompt: str,
        seconds: int,
        seed: int | None = None,
        negative_prompt: str = "",
    ) -> np.ndarray:
        """Run image-to-video inference and return decoded frames.

        Args:
            image_path: Conditioning image.
            prompt: Positive prompt.
            seconds: Clip length in seconds; must be positive.
            seed: RNG seed; a random one is drawn when ``None``.
            negative_prompt: Negative prompt.

        Returns:
            ``np.ndarray`` of shape ``[T, H, W, 3]``, dtype uint8, RGB.

        Raises:
            ValueError: ``seconds`` is not positive.
        """
        if seconds <= 0:
            raise ValueError(f"seconds must be positive, got {seconds!r}")
        if seed is None:
            seed = random.randint(0, 2**31 - 1)

        # Wan2.2 target_video_length is "frames including the conditioning
        # frame", so N seconds → N*fps + 1.
        target_frames = seconds * self.fps + 1

        from lightx2v.utils.input_info import update_input_info_from_dict  # type: ignore[import-not-found]

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
            out_path = tf.name
        try:
            log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d → %s",
                     prompt[:80], seconds, seed, out_path)
            update_input_info_from_dict(
                self._input_info_template,
                {
                    "seed": seed,
                    "prompt": prompt,
                    "negative_prompt": negative_prompt,
                    "image_path": image_path,
                    "save_result_path": out_path,
                    "target_video_length": target_frames,
                    "return_result_tensor": False,
                },
            )
            self._runner.run_pipeline(self._input_info_template)
            return _read_mp4_to_frames(out_path)
        finally:
            try:
                os.remove(out_path)
            except OSError:
                pass


# --- MP4 decoding helper ------------------------------------------------------

def _read_mp4_to_frames(path: str) -> np.ndarray:
    """Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``.

    Tries imageio's pyav plugin first and falls back to OpenCV.

    Raises:
        RuntimeError: the OpenCV fallback decoded no frames.
    """
    try:
        import imageio.v3 as iio  # type: ignore[import-not-found]
        frames = iio.imread(path, plugin="pyav")
        arr = np.asarray(frames)
        if arr.ndim == 3:
            arr = arr[None, ...]  # single frame → add a time axis
        return arr.astype(np.uint8)
    except Exception as e:  # pragma: no cover - fallback path
        log.warning("imageio decode failed (%s); falling back to cv2", e)

    import cv2  # type: ignore[import-not-found]
    cap = cv2.VideoCapture(path)
    decoded: list[np.ndarray] = []
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            decoded.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()  # release even if decoding/conversion raises
    if not decoded:
        raise RuntimeError(f"Failed to decode any frames from {path}")
    return np.stack(decoded, axis=0).astype(np.uint8)