"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V. This wrapper targets LightX2V's actual Python entry points (verified against the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main): from lightx2v.utils.set_config import set_config from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict from lightx2v.infer import init_runner args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...) config = set_config(args) input_info = init_empty_input_info(args.task, args.support_tasks) runner = init_runner(config) # loads all weights — done ONCE update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...}) runner.run_pipeline(input_info) # per-turn; MP4 written to save_result_path # LoRA hot-swap: runner.switch_lora(lora_path, strength) # swap in runner.switch_lora("", 0.0) # remove Model weights are loaded once at construction and held resident across turns so reflective mode doesn't re-pay the load cost each reply. Two HuggingFace repos are consumed on first run (cached under HF_HOME): - Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only. The bf16 DIT shards under high_noise_model/ and low_noise_model/ are SKIPPED via ignore_patterns — we replace them with fp8. - lightx2v/Wan2.2-Distill-Models — exactly two safetensors files: the fp8 e4m3 4-step distilled high/low noise DIT checkpoints (~15 GB each). """ from __future__ import annotations import argparse import json import logging import os import random import tempfile from typing import TYPE_CHECKING import numpy as np if TYPE_CHECKING: from server.video import LoRASpec log = logging.getLogger(__name__) FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" # The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the # T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8 # files from the distill repo replace the DIT weights entirely. We must keep # the config.json / index.json metadata under high_noise_model/ and # low_noise_model/ (LightX2V's set_config reads architecture params like # ``dim`` from them) and the tokenizer files under google/. BASE_REPO_IGNORE_PATTERNS = [ "high_noise_model/*.safetensors", "low_noise_model/*.safetensors", "assets/*", "examples/*", "nohup.out", "*.md", ] class Wan22Pipeline: """Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights. Constructor downloads (if needed) both HF repos, writes a runtime JSON config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``. ``generate_i2v`` runs one inference turn against the already-loaded runner. """ def __init__( self, base_repo: str, fp8_repo: str, config_json: str, model_cls: str = "wan2.2_moe_distill", resolution: int = 480, fps: int = 16, ): self.base_repo = base_repo self.fp8_repo = fp8_repo self.config_json_template = config_json self.model_cls = model_cls self.resolution = resolution self.fps = fps self._applied_loras: list[LoRASpec] = [] # 1. Resolve / download base repo (T5/VAE/config) and fp8 DIT ckpts. self._model_root = self._ensure_base_repo(base_repo) self._fp8_high, self._fp8_low = self._ensure_fp8_checkpoints(fp8_repo) # 2. Materialize a runtime JSON config with absolute ckpt paths. self._runtime_json_path = self._build_runtime_config() # 3. Build the argparse-like namespace LightX2V.set_config() expects. args = self._build_args( model_cls=model_cls, model_path=self._model_root, config_json=self._runtime_json_path, ) # 4. set_config → init_runner. Runner construction triggers weight load. # Imports are scoped here so ``import server.video_models.wan22`` # never pulls in lightx2v (tests can import this module on CPU). from lightx2v.utils.set_config import set_config # type: ignore[import-not-found] from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found] from lightx2v.infer import init_runner # type: ignore[import-not-found] log.info("LightX2V set_config (model_cls=%s, model_path=%s)", model_cls, self._model_root) self._config = set_config(args) self._input_info_template = init_empty_input_info( args.task, args.support_tasks ) log.info("LightX2V init_runner — loading weights (this takes a while)...") self._runner = init_runner(self._config) log.info("LightX2V runner loaded; weights resident.") # --- Weight provisioning ------------------------------------------------- @staticmethod def _ensure_base_repo(base_repo: str) -> str: """Return a local directory containing the Wan2.2 base support files. If ``base_repo`` is already a local directory, use it as-is. Otherwise snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT shards (they're replaced by the fp8 files). """ if os.path.isdir(base_repo): return base_repo from huggingface_hub import snapshot_download log.info("Downloading Wan2.2 base support files from %s " "(skipping bf16 DIT shards)...", base_repo) return snapshot_download( repo_id=base_repo, ignore_patterns=BASE_REPO_IGNORE_PATTERNS, ) @staticmethod def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]: """Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair. - If ``fp8_repo`` is a local directory, expect both files inside it. - Otherwise treat it as a HF repo id and download only the two files we need (not the ~150 GB of other variants in that repo). """ if not fp8_repo: raise ValueError("fp8_repo must be a HF repo id or local directory.") if os.path.isdir(fp8_repo): high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE) low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE) if not (os.path.isfile(high) and os.path.isfile(low)): raise FileNotFoundError( f"fp8 checkpoints not found in {fp8_repo}: expected " f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}" ) return high, low from huggingface_hub import hf_hub_download log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo) high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE) low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE) return high, low def _build_runtime_config(self) -> str: """Load the template JSON, inject absolute ckpt paths, persist to temp.""" with open(self.config_json_template, "r", encoding="utf-8") as f: cfg = json.load(f) # Drop editorial comments before passing to LightX2V. cfg.pop("_comment", None) cfg["high_noise_quantized_ckpt"] = self._fp8_high cfg["low_noise_quantized_ckpt"] = self._fp8_low cfg.setdefault("fps", self.fps) tmp = tempfile.NamedTemporaryFile( prefix="wan22_fp8_", suffix=".json", mode="w", delete=False, encoding="utf-8", ) json.dump(cfg, tmp, indent=2) tmp.close() log.info("Runtime LightX2V config: %s", tmp.name) return tmp.name @staticmethod def _build_args( *, model_cls: str, model_path: str, config_json: str ) -> argparse.Namespace: """Mirror every field from ``lightx2v.infer.main``'s argparse so ``set_config`` finds the attributes it expects. We only customize the model/task/path fields; everything else stays at the CLI defaults. """ return argparse.Namespace( seed=42, model_cls=model_cls, task="i2v", support_tasks=[], model_path=model_path, sf_model_path=None, config_json=config_json, use_prompt_enhancer=False, prompt="", negative_prompt="", image_path="", last_frame_path="", audio_path="", image_strength="1.0", image_frame_idx="", src_ref_images=None, src_video=None, src_mask=None, src_pose_path=None, src_face_path=None, src_bg_path=None, src_mask_path=None, pose=None, action_path=None, action_ckpt=None, save_result_path=None, return_result_tensor=False, target_shape=[], target_video_length=81, aspect_ratio="", video_path=None, sr_ratio=2.0, ) # --- LoRA -------------------------------------------------------------- def load_loras(self, specs: list["LoRASpec"]) -> None: """Apply LoRAs to the Wan2.2 MoE distill pipeline. Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"`` to route the LoRA to the correct expert. With ``lazy_load`` the DIT models are ``None`` at this point, so runtime ``switch_lora`` is impossible. Instead we inject ``lora_configs`` + ``lora_dynamic_apply`` into the runner config so the LoRAs are applied when the models materialise on first inference. Without ``lazy_load`` (models already resident) we call ``switch_lora`` with explicit high/low keyword args. """ if not specs: return # Resolve every path up-front (may trigger HF download). resolved: list[tuple["LoRASpec", str]] = [] for spec in specs: local_path = self._resolve_lora_path(spec.path) log.info(" LoRA %s → strength=%.2f target=%s (%s)", spec.name or spec.path, spec.weight, spec.target, local_path) resolved.append((spec, local_path)) lazy = self._config.get("lazy_load", False) if lazy: # Build the lora_configs list that LightX2V's lazy-load path # reads inside MultiDistillModelStruct.infer(). lora_cfgs = [] for spec, local_path in resolved: # LightX2V expects name "high_noise_model" / "low_noise_model" cfg_name = { "high_noise": "high_noise_model", "low_noise": "low_noise_model", }.get(spec.target) if cfg_name is None: raise ValueError( f"LoRA target must be 'high_noise' or 'low_noise', " f"got {spec.target!r}") lora_cfgs.append({ "name": cfg_name, "path": local_path, "strength": spec.weight, }) self._runner.set_config({ "lora_configs": lora_cfgs, "lora_dynamic_apply": True, }) else: # Models are loaded — use runtime hot-swap. high_path = high_strength = None low_path = low_strength = None for spec, local_path in resolved: if spec.target == "high_noise": high_path, high_strength = local_path, spec.weight elif spec.target == "low_noise": low_path, low_strength = local_path, spec.weight else: raise ValueError( f"LoRA target must be 'high_noise' or 'low_noise', " f"got {spec.target!r}") kwargs: dict = {} if high_path is not None: kwargs["high_lora_path"] = high_path kwargs["high_lora_strength"] = high_strength if low_path is not None: kwargs["low_lora_path"] = low_path kwargs["low_lora_strength"] = low_strength ok = self._runner.switch_lora(**kwargs) if not ok: raise RuntimeError( "runner.switch_lora returned False. Check that your " "LightX2V build supports runtime LoRA updates for " f"{self.model_cls}.") self._applied_loras = list(specs) def unload_loras(self) -> None: """Remove all currently applied LoRAs.""" if not self._applied_loras: return lazy = self._config.get("lazy_load", False) if lazy: self._runner.set_config({ "lora_configs": None, "lora_dynamic_apply": False, }) # If models were materialised, drop them so the next inference # recreates them without LoRAs. model_struct = getattr(self._runner, "model", None) if model_struct is not None and hasattr(model_struct, "model"): for i in range(len(model_struct.model)): model_struct.model[i] = None else: self._runner.switch_lora("", 0.0) self._applied_loras = [] @staticmethod def _resolve_lora_path(path: str) -> str: """Resolve a LoRA path. Supports: - Absolute/relative local paths (returned as-is if the file exists) - ``repo_id:filename`` HuggingFace references """ if os.path.isfile(path): return path if ":" in path and not path.startswith(("/", "./")): repo_id, filename = path.split(":", 1) from huggingface_hub import hf_hub_download return hf_hub_download(repo_id=repo_id, filename=filename) return path # --- Inference --------------------------------------------------------- def generate_i2v( self, image_path: str, prompt: str, seconds: int, seed: int | None = None, negative_prompt: str = "", ) -> np.ndarray: """Run image-to-video inference and return decoded frames. Returns ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB. """ if seed is None: seed = random.randint(0, 2**31 - 1) # Wan2.2 target_video_length is "frames including the conditioning # frame", so N seconds → N*fps + 1. target_frames = seconds * self.fps + 1 from lightx2v.utils.input_info import update_input_info_from_dict # type: ignore[import-not-found] with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf: out_path = tf.name try: log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d → %s", prompt[:80], seconds, seed, out_path) update_input_info_from_dict( self._input_info_template, { "seed": seed, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_path, "save_result_path": out_path, "target_video_length": target_frames, "return_result_tensor": False, }, ) self._runner.run_pipeline(self._input_info_template) return _read_mp4_to_frames(out_path) finally: try: os.remove(out_path) except OSError: pass # --- MP4 decoding helper ------------------------------------------------------ def _read_mp4_to_frames(path: str) -> np.ndarray: """Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``.""" try: import imageio.v3 as iio # type: ignore[import-not-found] frames = iio.imread(path, plugin="pyav") arr = np.asarray(frames) if arr.ndim == 3: arr = arr[None, ...] return arr.astype(np.uint8) except Exception as e: # pragma: no cover - fallback path log.warning("imageio decode failed (%s); falling back to cv2", e) import cv2 # type: ignore[import-not-found] cap = cv2.VideoCapture(path) frames: list[np.ndarray] = [] while True: ok, frame = cap.read() if not ok: break frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) cap.release() if not frames: raise RuntimeError(f"Failed to decode any frames from {path}") return np.stack(frames, axis=0).astype(np.uint8)