first stab at adding video
This commit is contained in:
@@ -0,0 +1,423 @@
|
||||
"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V.
|
||||
|
||||
This wrapper targets LightX2V's actual Python entry points (verified against
|
||||
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
|
||||
|
||||
from lightx2v.utils.set_config import set_config
|
||||
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
|
||||
from lightx2v.infer import init_runner
|
||||
|
||||
args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...)
|
||||
config = set_config(args)
|
||||
input_info = init_empty_input_info(args.task, args.support_tasks)
|
||||
runner = init_runner(config) # loads all weights — done ONCE
|
||||
|
||||
update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...})
|
||||
runner.run_pipeline(input_info) # per-turn; MP4 written to save_result_path
|
||||
# LoRA hot-swap:
|
||||
runner.switch_lora(lora_path, strength) # swap in
|
||||
runner.switch_lora("", 0.0) # remove
|
||||
|
||||
Model weights are loaded once at construction and held resident across turns
|
||||
so reflective mode doesn't re-pay the load cost each reply.
|
||||
|
||||
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
|
||||
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards under high_noise_model/
|
||||
and low_noise_model/ are SKIPPED via
|
||||
ignore_patterns — we replace them with fp8.
|
||||
- lightx2v/Wan2.2-Distill-Models — exactly two safetensors files:
|
||||
the fp8 e4m3 4-step distilled high/low
|
||||
noise DIT checkpoints (~15 GB each).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from server.video import LoRASpec
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
|
||||
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8
|
||||
# files from the distill repo replace the DIT weights entirely. We must keep
|
||||
# the config.json / index.json metadata under high_noise_model/ and
|
||||
# low_noise_model/ (LightX2V's set_config reads architecture params like
|
||||
# ``dim`` from them) and the tokenizer files under google/.
|
||||
BASE_REPO_IGNORE_PATTERNS = [
|
||||
"high_noise_model/*.safetensors",
|
||||
"low_noise_model/*.safetensors",
|
||||
"assets/*",
|
||||
"examples/*",
|
||||
"nohup.out",
|
||||
"*.md",
|
||||
]
|
||||
|
||||
|
||||
class Wan22Pipeline:
|
||||
"""Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights.
|
||||
|
||||
Constructor downloads (if needed) both HF repos, writes a runtime JSON
|
||||
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
|
||||
``generate_i2v`` runs one inference turn against the already-loaded runner.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_repo: str,
|
||||
fp8_repo: str,
|
||||
config_json: str,
|
||||
model_cls: str = "wan2.2_moe_distill",
|
||||
resolution: int = 480,
|
||||
fps: int = 16,
|
||||
):
|
||||
self.base_repo = base_repo
|
||||
self.fp8_repo = fp8_repo
|
||||
self.config_json_template = config_json
|
||||
self.model_cls = model_cls
|
||||
self.resolution = resolution
|
||||
self.fps = fps
|
||||
self._applied_loras: list[LoRASpec] = []
|
||||
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and fp8 DIT ckpts.
|
||||
self._model_root = self._ensure_base_repo(base_repo)
|
||||
self._fp8_high, self._fp8_low = self._ensure_fp8_checkpoints(fp8_repo)
|
||||
|
||||
# 2. Materialize a runtime JSON config with absolute ckpt paths.
|
||||
self._runtime_json_path = self._build_runtime_config()
|
||||
|
||||
# 3. Build the argparse-like namespace LightX2V.set_config() expects.
|
||||
args = self._build_args(
|
||||
model_cls=model_cls,
|
||||
model_path=self._model_root,
|
||||
config_json=self._runtime_json_path,
|
||||
)
|
||||
|
||||
# 4. set_config → init_runner. Runner construction triggers weight load.
|
||||
# Imports are scoped here so ``import server.video_models.wan22``
|
||||
# never pulls in lightx2v (tests can import this module on CPU).
|
||||
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
|
||||
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
|
||||
from lightx2v.infer import init_runner # type: ignore[import-not-found]
|
||||
|
||||
log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
|
||||
model_cls, self._model_root)
|
||||
self._config = set_config(args)
|
||||
|
||||
self._input_info_template = init_empty_input_info(
|
||||
args.task, args.support_tasks
|
||||
)
|
||||
|
||||
log.info("LightX2V init_runner — loading weights (this takes a while)...")
|
||||
self._runner = init_runner(self._config)
|
||||
log.info("LightX2V runner loaded; weights resident.")
|
||||
|
||||
# --- Weight provisioning -------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _ensure_base_repo(base_repo: str) -> str:
|
||||
"""Return a local directory containing the Wan2.2 base support files.
|
||||
|
||||
If ``base_repo`` is already a local directory, use it as-is. Otherwise
|
||||
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
|
||||
shards (they're replaced by the fp8 files).
|
||||
"""
|
||||
if os.path.isdir(base_repo):
|
||||
return base_repo
|
||||
from huggingface_hub import snapshot_download
|
||||
log.info("Downloading Wan2.2 base support files from %s "
|
||||
"(skipping bf16 DIT shards)...", base_repo)
|
||||
return snapshot_download(
|
||||
repo_id=base_repo,
|
||||
ignore_patterns=BASE_REPO_IGNORE_PATTERNS,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]:
|
||||
"""Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair.
|
||||
|
||||
- If ``fp8_repo`` is a local directory, expect both files inside it.
|
||||
- Otherwise treat it as a HF repo id and download only the two files
|
||||
we need (not the ~150 GB of other variants in that repo).
|
||||
"""
|
||||
if not fp8_repo:
|
||||
raise ValueError("fp8_repo must be a HF repo id or local directory.")
|
||||
if os.path.isdir(fp8_repo):
|
||||
high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE)
|
||||
low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE)
|
||||
if not (os.path.isfile(high) and os.path.isfile(low)):
|
||||
raise FileNotFoundError(
|
||||
f"fp8 checkpoints not found in {fp8_repo}: expected "
|
||||
f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}"
|
||||
)
|
||||
return high, low
|
||||
from huggingface_hub import hf_hub_download
|
||||
log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo)
|
||||
high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE)
|
||||
low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE)
|
||||
return high, low
|
||||
|
||||
def _build_runtime_config(self) -> str:
|
||||
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
|
||||
with open(self.config_json_template, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
# Drop editorial comments before passing to LightX2V.
|
||||
cfg.pop("_comment", None)
|
||||
cfg["high_noise_quantized_ckpt"] = self._fp8_high
|
||||
cfg["low_noise_quantized_ckpt"] = self._fp8_low
|
||||
cfg.setdefault("fps", self.fps)
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(
|
||||
prefix="wan22_fp8_", suffix=".json",
|
||||
mode="w", delete=False, encoding="utf-8",
|
||||
)
|
||||
json.dump(cfg, tmp, indent=2)
|
||||
tmp.close()
|
||||
log.info("Runtime LightX2V config: %s", tmp.name)
|
||||
return tmp.name
|
||||
|
||||
@staticmethod
|
||||
def _build_args(
|
||||
*, model_cls: str, model_path: str, config_json: str
|
||||
) -> argparse.Namespace:
|
||||
"""Mirror every field from ``lightx2v.infer.main``'s argparse so
|
||||
``set_config`` finds the attributes it expects. We only customize the
|
||||
model/task/path fields; everything else stays at the CLI defaults.
|
||||
"""
|
||||
return argparse.Namespace(
|
||||
seed=42,
|
||||
model_cls=model_cls,
|
||||
task="i2v",
|
||||
support_tasks=[],
|
||||
model_path=model_path,
|
||||
sf_model_path=None,
|
||||
config_json=config_json,
|
||||
use_prompt_enhancer=False,
|
||||
prompt="",
|
||||
negative_prompt="",
|
||||
image_path="",
|
||||
last_frame_path="",
|
||||
audio_path="",
|
||||
image_strength="1.0",
|
||||
image_frame_idx="",
|
||||
src_ref_images=None,
|
||||
src_video=None,
|
||||
src_mask=None,
|
||||
src_pose_path=None,
|
||||
src_face_path=None,
|
||||
src_bg_path=None,
|
||||
src_mask_path=None,
|
||||
pose=None,
|
||||
action_path=None,
|
||||
action_ckpt=None,
|
||||
save_result_path=None,
|
||||
return_result_tensor=False,
|
||||
target_shape=[],
|
||||
target_video_length=81,
|
||||
aspect_ratio="",
|
||||
video_path=None,
|
||||
sr_ratio=2.0,
|
||||
)
|
||||
|
||||
# --- LoRA --------------------------------------------------------------
|
||||
|
||||
def load_loras(self, specs: list["LoRASpec"]) -> None:
|
||||
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
|
||||
|
||||
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
|
||||
to route the LoRA to the correct expert.
|
||||
|
||||
With ``lazy_load`` the DIT models are ``None`` at this point, so
|
||||
runtime ``switch_lora`` is impossible. Instead we inject
|
||||
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
|
||||
the LoRAs are applied when the models materialise on first inference.
|
||||
|
||||
Without ``lazy_load`` (models already resident) we call
|
||||
``switch_lora`` with explicit high/low keyword args.
|
||||
"""
|
||||
if not specs:
|
||||
return
|
||||
|
||||
# Resolve every path up-front (may trigger HF download).
|
||||
resolved: list[tuple["LoRASpec", str]] = []
|
||||
for spec in specs:
|
||||
local_path = self._resolve_lora_path(spec.path)
|
||||
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
|
||||
spec.name or spec.path, spec.weight, spec.target,
|
||||
local_path)
|
||||
resolved.append((spec, local_path))
|
||||
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
# Build the lora_configs list that LightX2V's lazy-load path
|
||||
# reads inside MultiDistillModelStruct.infer().
|
||||
lora_cfgs = []
|
||||
for spec, local_path in resolved:
|
||||
# LightX2V expects name "high_noise_model" / "low_noise_model"
|
||||
cfg_name = {
|
||||
"high_noise": "high_noise_model",
|
||||
"low_noise": "low_noise_model",
|
||||
}.get(spec.target)
|
||||
if cfg_name is None:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
lora_cfgs.append({
|
||||
"name": cfg_name,
|
||||
"path": local_path,
|
||||
"strength": spec.weight,
|
||||
})
|
||||
self._runner.set_config({
|
||||
"lora_configs": lora_cfgs,
|
||||
"lora_dynamic_apply": True,
|
||||
})
|
||||
else:
|
||||
# Models are loaded — use runtime hot-swap.
|
||||
high_path = high_strength = None
|
||||
low_path = low_strength = None
|
||||
for spec, local_path in resolved:
|
||||
if spec.target == "high_noise":
|
||||
high_path, high_strength = local_path, spec.weight
|
||||
elif spec.target == "low_noise":
|
||||
low_path, low_strength = local_path, spec.weight
|
||||
else:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
|
||||
kwargs: dict = {}
|
||||
if high_path is not None:
|
||||
kwargs["high_lora_path"] = high_path
|
||||
kwargs["high_lora_strength"] = high_strength
|
||||
if low_path is not None:
|
||||
kwargs["low_lora_path"] = low_path
|
||||
kwargs["low_lora_strength"] = low_strength
|
||||
ok = self._runner.switch_lora(**kwargs)
|
||||
if not ok:
|
||||
raise RuntimeError(
|
||||
"runner.switch_lora returned False. Check that your "
|
||||
"LightX2V build supports runtime LoRA updates for "
|
||||
f"{self.model_cls}.")
|
||||
|
||||
self._applied_loras = list(specs)
|
||||
|
||||
def unload_loras(self) -> None:
|
||||
"""Remove all currently applied LoRAs."""
|
||||
if not self._applied_loras:
|
||||
return
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
self._runner.set_config({
|
||||
"lora_configs": None,
|
||||
"lora_dynamic_apply": False,
|
||||
})
|
||||
# If models were materialised, drop them so the next inference
|
||||
# recreates them without LoRAs.
|
||||
model_struct = getattr(self._runner, "model", None)
|
||||
if model_struct is not None and hasattr(model_struct, "model"):
|
||||
for i in range(len(model_struct.model)):
|
||||
model_struct.model[i] = None
|
||||
else:
|
||||
self._runner.switch_lora("", 0.0)
|
||||
self._applied_loras = []
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_path(path: str) -> str:
|
||||
"""Resolve a LoRA path. Supports:
|
||||
- Absolute/relative local paths (returned as-is if the file exists)
|
||||
- ``repo_id:filename`` HuggingFace references
|
||||
"""
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
if ":" in path and not path.startswith(("/", "./")):
|
||||
repo_id, filename = path.split(":", 1)
|
||||
from huggingface_hub import hf_hub_download
|
||||
return hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
return path
|
||||
|
||||
# --- Inference ---------------------------------------------------------
|
||||
|
||||
def generate_i2v(
|
||||
self,
|
||||
image_path: str,
|
||||
prompt: str,
|
||||
seconds: int,
|
||||
seed: int | None = None,
|
||||
negative_prompt: str = "",
|
||||
) -> np.ndarray:
|
||||
"""Run image-to-video inference and return decoded frames.
|
||||
|
||||
Returns ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB.
|
||||
"""
|
||||
if seed is None:
|
||||
seed = random.randint(0, 2**31 - 1)
|
||||
|
||||
# Wan2.2 target_video_length is "frames including the conditioning
|
||||
# frame", so N seconds → N*fps + 1.
|
||||
target_frames = seconds * self.fps + 1
|
||||
|
||||
from lightx2v.utils.input_info import update_input_info_from_dict # type: ignore[import-not-found]
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
|
||||
out_path = tf.name
|
||||
try:
|
||||
log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d → %s",
|
||||
prompt[:80], seconds, seed, out_path)
|
||||
update_input_info_from_dict(
|
||||
self._input_info_template,
|
||||
{
|
||||
"seed": seed,
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"image_path": image_path,
|
||||
"save_result_path": out_path,
|
||||
"target_video_length": target_frames,
|
||||
"return_result_tensor": False,
|
||||
},
|
||||
)
|
||||
self._runner.run_pipeline(self._input_info_template)
|
||||
return _read_mp4_to_frames(out_path)
|
||||
finally:
|
||||
try:
|
||||
os.remove(out_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
# --- MP4 decoding helper ------------------------------------------------------
|
||||
|
||||
def _read_mp4_to_frames(path: str) -> np.ndarray:
|
||||
"""Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``."""
|
||||
try:
|
||||
import imageio.v3 as iio # type: ignore[import-not-found]
|
||||
frames = iio.imread(path, plugin="pyav")
|
||||
arr = np.asarray(frames)
|
||||
if arr.ndim == 3:
|
||||
arr = arr[None, ...]
|
||||
return arr.astype(np.uint8)
|
||||
except Exception as e: # pragma: no cover - fallback path
|
||||
log.warning("imageio decode failed (%s); falling back to cv2", e)
|
||||
import cv2 # type: ignore[import-not-found]
|
||||
cap = cv2.VideoCapture(path)
|
||||
frames: list[np.ndarray] = []
|
||||
while True:
|
||||
ok, frame = cap.read()
|
||||
if not ok:
|
||||
break
|
||||
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
cap.release()
|
||||
if not frames:
|
||||
raise RuntimeError(f"Failed to decode any frames from {path}")
|
||||
return np.stack(frames, axis=0).astype(np.uint8)
|
||||
Reference in New Issue
Block a user