644 lines
25 KiB
Python
644 lines
25 KiB
Python
"""Wan2.2-TI2V-5B-Turbo (dense) image-to-video wrapper via LightX2V.

This wrapper targets LightX2V's Python entry points (verified against
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):

    from lightx2v.utils.set_config import set_config
    from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
    from lightx2v.infer import init_runner

    args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...)
    config = set_config(args)
    input_info = init_empty_input_info(args.task, args.support_tasks)
    runner = init_runner(config)  # loads all weights — done ONCE

    update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...})
    runner.run_pipeline(input_info)  # per turn; MP4 written to save_result_path

    # LoRA hot-swap (upstream API). NOTE: GGUF DIT weights do not support
    # switch_lora, so this wrapper instead passes ``lora_configs`` with
    # ``lora_dynamic_apply=True`` through ``runner.set_config``.
    runner.switch_lora(lora_path, strength)  # swap in
    runner.switch_lora("", 0.0)              # remove

Model weights are loaded once at construction and held resident across turns,
so reflective mode doesn't re-pay the load cost on each reply.

Up to three HuggingFace repos are consumed on first run (cached under HF_HOME):

- Wan-AI/Wan2.2-TI2V-5B — T5 encoder, VAE, tokenizer/config only. The bf16
  DIT shards are SKIPPED via ignore_patterns and replaced by the GGUF
  checkpoint from dit_repo.
- dit_repo (configurable) — a single dense GGUF DIT checkpoint, e.g.
  hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
- lightx2v/Encoders — the fp8 T5 encoder, fetched only when ``t5_quantized``
  is enabled (the default).
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import tempfile
|
|
from typing import TYPE_CHECKING
|
|
|
|
import numpy as np
|
|
|
|
if TYPE_CHECKING:
|
|
from server.video import LoRASpec
|
|
|
|
log = logging.getLogger(__name__)


# --- GGUF filename for the dense 5B Turbo repo ------------------------------
# hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF keeps every quant level flat at the repo
# root; ``{quant}`` is filled with the part after ``gguf-`` in the scheme
# name (e.g. ``Q8_0``).
GGUF_TURBO_5B_FILE = "Wan2_2-TI2V-5B-Turbo-{quant}.gguf"

# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
T5_FP8_REPO, T5_FP8_FILE = (
    "lightx2v/Encoders",
    "models_t5_umt5-xxl-enc-fp8.safetensors",
)

# Glob patterns skipped when snapshotting the Wan-AI base repo. That repo
# bundles bf16 DIT weight shards next to the T5/VAE/tokenizer support files;
# only the support files are needed because the GGUF checkpoint from dit_repo
# supplies the DIT. config.json and tokenizer files are intentionally kept.
BASE_REPO_IGNORE_PATTERNS = [
    # Weight files (``*.pt`` does not match ``.pth`` files).
    "*.pt",
    "diffusion_pytorch_model*.safetensors",
    # Repo extras that are never read at runtime.
    "assets/*",
    "examples/*",
    "nohup.out",
    "*.md",
]


def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
    """Recursively find fp32 tensors reachable from ``obj`` and cast them to fp16.

    The dense ``wan2.2`` DIT isn't a standard ``nn.Module`` — some fp32
    tensors (conv3d bias etc.) live outside ``pre_weight``/``post_weight``
    and are missed by the structured sweep. This generic attribute traversal
    catches them. Recursion is bounded by ``depth`` (stops past 6) and by a
    ``visited`` set of object ids to avoid cycles.

    Only plain attribute-holding objects are descended into. Tensors and
    Python modules are never traversed: a tensor's ``dir()`` has hundreds of
    entries and its ``.T``/``.mT``/``.data``/``.real`` views are fresh objects
    on every access, so descending into them fans out enormously. A module
    reference would walk (and mutate) an entire library namespace. Lists
    and dicts are not traversed either.

    Args:
        obj: Root object whose attributes are scanned.
        visited: Ids of objects already scanned (internal; leave ``None``).
        depth: Current recursion depth (internal; leave ``0``).

    Returns:
        Number of attributes rebound to fp16 tensors.
    """
    import types

    import torch

    if visited is None:
        visited = set()
    obj_id = id(obj)
    if obj_id in visited or depth > 6:
        return 0
    visited.add(obj_id)

    n = 0
    for attr_name in dir(obj):
        if attr_name.startswith("__"):
            continue
        try:
            val = getattr(obj, attr_name)
        except Exception:
            continue  # properties that raise on access are simply skipped

        if isinstance(val, torch.Tensor):
            if val.dtype == torch.float32 and val.numel() > 0:
                try:
                    setattr(obj, attr_name, val.to(torch.float16))
                    n += 1
                except Exception:
                    pass  # best effort: read-only / non-assignable attribute
            continue  # never descend into tensors (see docstring)

        if isinstance(val, types.ModuleType):
            continue  # never walk library namespaces

        if hasattr(val, "__dict__") and not callable(val):
            n += _cast_all_fp32_tensors(val, visited, depth + 1)
    return n


def _patch_fp8_scaled_mm_for_blackwell() -> None:
    """Replace ``sgl_kernel.fp8_scaled_mm`` with ``torch._scaled_mm`` on SM12x GPUs.

    sgl_kernel's CUTLASS-based fp8 GEMM doesn't ship SM120 kernels yet.
    PyTorch 2.8+'s native ``_scaled_mm`` works on all architectures, including
    Blackwell.

    The patch is idempotent: a marker attribute is set on the ``sgl_kernel``
    module. It is a no-op when sgl_kernel is not installed, CUDA is
    unavailable, or the current device's compute-capability major version
    is below 12.
    """
    try:
        import sgl_kernel  # type: ignore[import-not-found]
    except ImportError:
        return  # no sgl_kernel → fp8 T5 not in use

    if getattr(sgl_kernel, "_fp8_patched_for_blackwell", False):
        return

    import torch

    if not torch.cuda.is_available():
        return

    cap = torch.cuda.get_device_capability()
    if cap[0] < 12:
        # Only SM12x is patched; SM10x (datacenter Blackwell) and older
        # keep the sgl_kernel implementation.
        return

    def _torch_fp8_scaled_mm(
        a: torch.Tensor,
        b: torch.Tensor,
        a_scale: torch.Tensor,
        b_scale: torch.Tensor,
        out_dtype: torch.dtype,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # torch._scaled_mm expects (M,K) @ (N,K).t() with:
        #   scale_a: scalar or (M,1)
        #   scale_b: scalar or (1,N)
        # sgl_kernel provides scale_b as (N,1) — transpose it.
        if b_scale.dim() == 2 and b_scale.shape[1] == 1:
            b_scale = b_scale.t()
        # _scaled_mm requires B to be column-major (stride(0)==1). This is a
        # no-copy view when B already is column-major.
        b_col_major = b.t().contiguous().t()
        return torch._scaled_mm(
            a,
            b_col_major,
            scale_a=a_scale,
            scale_b=b_scale,
            out_dtype=out_dtype,
            bias=bias,
        )

    sgl_kernel.fp8_scaled_mm = _torch_fp8_scaled_mm
    sgl_kernel._fp8_patched_for_blackwell = True
    log.info("Patched sgl_kernel.fp8_scaled_mm → torch._scaled_mm for Blackwell (SM%d%d).", *cap)


class Wan22Pipeline:
    """Wrapper around LightX2V's dense Wan2.2-TI2V-5B-Turbo runner.

    The 5B Turbo repo ships a single dense DIT checkpoint (not MoE) as GGUF.
    ``dit_quant_scheme`` must be a GGUF variant (``gguf-Q8_0`` default,
    ``gguf-Q4_K_M`` for lower VRAM); no fp8 5B Turbo weights exist.

    The constructor downloads (if needed) the HF repos, writes a runtime JSON
    config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
    ``generate_i2v`` runs one inference turn against the already-loaded runner.

    LightX2V picks its compute dtype from the process-global ``DTYPE``
    environment variable (read through cached ``GET_DTYPE()``). This class
    rewrites that variable during construction and around every T5 call.
    """

    def __init__(
        self,
        base_repo: str,
        dit_repo: str,
        config_json: str,
        model_cls: str = "wan2.2",
        resolution: int = 480,
        fps: int = 16,
        dit_quant_scheme: str = "gguf-Q8_0",
        t5_quantized: bool = True,
    ):
        """Provision weights and load the LightX2V runner (done once).

        Args:
            base_repo: HF repo id or local dir with Wan2.2 T5/VAE/config files.
            dit_repo: HF repo id or local dir holding the dense GGUF DIT file.
            config_json: Path to the LightX2V JSON config template.
            model_cls: LightX2V model class name.
            resolution: Stored as ``self.resolution``; not read by this class.
            fps: Frames per second used to turn seconds into frames. It is
                also written to the runtime config if the template has no
                ``fps`` key.
            dit_quant_scheme: ``gguf-<QUANT>``; selects the checkpoint file.
            t5_quantized: Download and use the fp8 T5 encoder.

        Raises:
            ValueError: ``dit_quant_scheme`` is not GGUF, or ``dit_repo`` is empty.
            FileNotFoundError: local ``dit_repo`` lacks the expected GGUF file.
        """
        self.base_repo = base_repo
        self.dit_repo = dit_repo
        self.config_json_template = config_json
        self.model_cls = model_cls
        self.resolution = resolution
        self.fps = fps
        self.dit_quant_scheme = dit_quant_scheme
        self.t5_quantized = t5_quantized
        self._applied_loras: list[LoRASpec] = []

        self._is_gguf = dit_quant_scheme.startswith("gguf-")
        if not self._is_gguf:
            raise ValueError(
                f"dit_quant_scheme must be a GGUF variant for dense 5B Turbo "
                f"(got {dit_quant_scheme!r}); no fp8 5B Turbo weights exist."
            )

        # 1. Resolve / download base repo (T5/VAE/config) and DIT ckpt.
        self._model_root = self._ensure_base_repo(base_repo)
        self._dit_ckpt = self._ensure_dit_checkpoint(dit_repo, dit_quant_scheme)
        self._t5_fp8_ckpt = self._ensure_t5_fp8() if t5_quantized else None

        # 2. Materialize a runtime JSON config with absolute ckpt paths.
        self._runtime_json_path = self._build_runtime_config()

        # 3. Build the argparse-like namespace LightX2V.set_config() expects.
        args = self._build_args(
            model_cls=model_cls,
            model_path=self._model_root,
            config_json=self._runtime_json_path,
        )

        # 4. Import LightX2V (scoped here so ``import server.video_models.wan22``
        #    never pulls in lightx2v — tests can import this module on CPU).
        from lightx2v.utils.set_config import set_config  # type: ignore[import-not-found]
        from lightx2v.utils.input_info import init_empty_input_info  # type: ignore[import-not-found]
        from lightx2v.infer import init_runner  # type: ignore[import-not-found]

        _patch_fp8_scaled_mm_for_blackwell()

        # 5. Load all models under DTYPE=BF16 so T5 (which is hardcoded to
        #    bf16 weights) initialises its offload buffers correctly. Set it
        #    explicitly: step 6 of a previous instance in this process leaves
        #    DTYPE=FP16 behind, which would otherwise make this load run in
        #    FP16. We flip to FP16 *after* init_runner completes.
        self._set_lightx2v_dtype("BF16")
        log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
                 model_cls, self._model_root)
        self._config = set_config(args)

        self._input_info_template = init_empty_input_info(
            args.task, args.support_tasks
        )

        log.info("LightX2V init_runner — loading weights (this takes a while)...")
        self._runner = init_runner(self._config)
        log.info("LightX2V runner loaded; weights resident.")

        # 6. GGUF: switch global DTYPE to FP16 for inference. GGUF DIT
        #    dequantises to fp16, and many intermediate tensors inside the
        #    DIT forward pass are allocated via GET_DTYPE(). The T5 encoder
        #    is wrapped to temporarily restore BF16 during its forward.
        if self._is_gguf:
            self._set_lightx2v_dtype("FP16")
            from lightx2v.utils.envs import GET_DTYPE  # type: ignore[import-not-found]
            log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
            self._patch_t5_dtype_for_gguf()
            self._patch_vae_dtype_for_gguf()
            self._patch_dit_fp32_weights_for_gguf()

    # --- LightX2V global dtype ---------------------------------------------

    @staticmethod
    def _set_lightx2v_dtype(flag: str) -> None:
        """Point LightX2V's global compute dtype at ``flag`` (``"BF16"``/``"FP16"``).

        Sets ``os.environ["DTYPE"]`` and clears both cached dtype getters so
        the next call re-reads the environment. ``GET_SENSITIVE_DTYPE`` is
        cleared too because it presumably falls back to ``DTYPE``; clearing
        it is harmless either way.
        """
        from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE  # type: ignore[import-not-found]

        os.environ["DTYPE"] = flag
        GET_DTYPE.cache_clear()
        GET_SENSITIVE_DTYPE.cache_clear()

    # --- GGUF dtype compatibility patch ----------------------------------------

    def _patch_t5_dtype_for_gguf(self) -> None:
        """Wrap the T5 encoder so it temporarily restores DTYPE=BF16.

        The T5 encoder is hardcoded to bfloat16 weights (wan_runner.py). When
        the global DTYPE is FP16 (required for GGUF DIT), the T5's CPU-offload
        path breaks because intermediate tensor dtypes no longer match the bf16
        weights. We wrap ``run_text_encoder`` to temporarily flip GET_DTYPE()
        back to bf16, then restore fp16 (even on error) before the DIT runs.
        The bf16 outputs are cast to fp16 for the GGUF DIT.
        """
        import types

        import torch

        runner = self._runner
        orig_run_text_encoder = runner.run_text_encoder.__func__
        set_dtype = self._set_lightx2v_dtype

        def _to_fp16(x):
            """Cast bf16 tensors to fp16, recursing through list/tuple/dict."""
            if isinstance(x, torch.Tensor) and x.dtype == torch.bfloat16:
                return x.to(torch.float16)
            if isinstance(x, list):
                return [_to_fp16(v) for v in x]
            if isinstance(x, tuple):
                items = [_to_fp16(v) for v in x]
                # Namedtuples keep their own type; plain tuples stay tuples.
                return type(x)(*items) if hasattr(x, "_fields") else tuple(items)
            if isinstance(x, dict):
                return {k: _to_fp16(v) for k, v in x.items()}
            return x

        def bf16_text_encoder(self_runner, *args, **kwargs):
            # Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
            set_dtype("BF16")
            try:
                result = orig_run_text_encoder(self_runner, *args, **kwargs)
            finally:
                # Restore FP16 for the DIT / rest of the pipeline.
                set_dtype("FP16")
            # Cast bf16 T5 outputs to fp16 so they match the GGUF DIT dtype.
            return _to_fp16(result)

        runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
        log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")

    def _patch_vae_dtype_for_gguf(self) -> None:
        """Cast VAE encoder/decoder weights to fp16 to match GGUF DIT dtype.

        The VAE weights load as bf16 (the default). Under GGUF the DIT runs in
        fp16 and the runner casts VAE inputs via ``.to(GET_DTYPE())`` — which
        under DTYPE=FP16 collides with bf16 VAE weights in Conv3d. Since the
        VAE is a plain float model (not quantized), simply converting its
        weights to fp16 avoids both input-vs-weight mismatches and the need
        for any runtime dtype juggling.
        """
        import torch

        runner = self._runner
        for name in ("vae_encoder", "vae_decoder"):
            mod = getattr(runner, name, None)
            if mod is None:
                continue
            inner = getattr(mod, "model", mod)
            if hasattr(inner, "to"):
                inner.to(dtype=torch.float16)
            # The outer WanVAE wrapper also holds mean/inv_std/scale tensors
            # used by encode/decode (z = z/inv_std + mean). Cast them too, or
            # the first op upcasts fp16 latents back to fp32/bf16.
            for attr in ("mean", "inv_std"):
                t = getattr(mod, attr, None)
                if isinstance(t, torch.Tensor):
                    setattr(mod, attr, t.to(torch.float16))
            scale = getattr(mod, "scale", None)
            if isinstance(scale, list):
                mod.scale = [
                    t.to(torch.float16) if isinstance(t, torch.Tensor) else t
                    for t in scale
                ]
        log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")

    def _patch_dit_fp32_weights_for_gguf(self) -> None:
        """Cast leftover fp32 DIT weights to fp16 (dense model).

        GGUF dequantises the transformer blocks to fp16, but a handful of
        non-quantised weights (notably ``patch_embedding.pin_weight``) end up
        loaded as fp32. That breaks the first conv in the DIT forward pass
        (fp16 input vs fp32 weight). Dense ``wan2.2`` exposes the model
        directly at ``runner.model`` (no MoE wrapper). After the structured
        pre/post weight sweep, we also run a recursive traversal to catch
        fp32 conv3d biases etc. that live outside pre/post_weight.
        """
        runner = self._runner
        n_struct = self._cast_fp32_dit_weights_in_model(runner.model)
        n_extra = _cast_all_fp32_tensors(runner.model)
        log.info(
            "Cast %d (structured) + %d (recursive) fp32 DIT tensors to fp16 for GGUF pipeline.",
            n_struct, n_extra,
        )

    @staticmethod
    def _cast_fp32_dit_weights_in_model(m) -> int:
        """Cast fp32 ``weight``/``bias``/``pin_weight``/``pin_bias`` to fp16.

        Scans each public attribute of ``m.pre_weight`` and ``m.post_weight``
        (either may be missing or ``None``). Pinned ``pin_*`` tensors are
        re-pinned after the cast on a best-effort basis.

        Returns:
            Number of tensors cast.
        """
        import torch

        n_cast = 0
        for weights_attr in ("pre_weight", "post_weight"):
            w = getattr(m, weights_attr, None)
            if w is None:
                continue
            for sub_name in dir(w):
                if sub_name.startswith("_"):
                    continue
                try:
                    sub = getattr(w, sub_name)
                except Exception:
                    continue
                if sub is None:
                    continue
                for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
                    t = getattr(sub, t_name, None)
                    if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
                        casted = t.to(torch.float16)
                        # .to() returns unpinned memory; keep pin_* buffers
                        # pinned so async host→device copies still work.
                        if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
                            try:
                                casted = casted.pin_memory()
                            except RuntimeError:
                                pass
                        setattr(sub, t_name, casted)
                        n_cast += 1
        return n_cast

    # --- Weight provisioning -------------------------------------------------

    @staticmethod
    def _ensure_base_repo(base_repo: str) -> str:
        """Return a local directory containing the Wan2.2 base support files.

        If ``base_repo`` is already a local directory, use it as-is. Otherwise
        snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
        shards (they're replaced by the quantised files).
        """
        if os.path.isdir(base_repo):
            return base_repo
        from huggingface_hub import snapshot_download
        log.info("Downloading Wan2.2 base support files from %s "
                 "(skipping bf16 DIT shards)...", base_repo)
        return snapshot_download(
            repo_id=base_repo,
            ignore_patterns=BASE_REPO_IGNORE_PATTERNS,
        )

    @staticmethod
    def _ensure_dit_checkpoint(
        dit_repo: str,
        dit_quant_scheme: str,
    ) -> str:
        """Return the local path to the single dense GGUF DIT checkpoint.

        Raises:
            ValueError: empty ``dit_repo`` or non-GGUF ``dit_quant_scheme``.
            FileNotFoundError: local ``dit_repo`` lacks the expected file.
        """
        if not dit_repo:
            raise ValueError("dit_repo must be a HF repo id or local directory.")
        if not dit_quant_scheme.startswith("gguf-"):
            raise ValueError(
                f"Only GGUF quant schemes are supported for dense 5B Turbo "
                f"(got {dit_quant_scheme!r})."
            )

        # Strip only the leading "gguf-" (str.replace would hit every match).
        quant = dit_quant_scheme[len("gguf-"):]
        filename = GGUF_TURBO_5B_FILE.format(quant=quant)

        if os.path.isdir(dit_repo):
            path = os.path.join(dit_repo, filename)
            if not os.path.isfile(path):
                raise FileNotFoundError(
                    f"DIT checkpoint not found in {dit_repo}: expected {filename}"
                )
            return path

        from huggingface_hub import hf_hub_download
        log.info("Downloading %s DIT checkpoint from %s ...",
                 dit_quant_scheme, dit_repo)
        return hf_hub_download(repo_id=dit_repo, filename=filename)

    @staticmethod
    def _ensure_t5_fp8() -> str:
        """Download the fp8 T5 encoder from lightx2v/Encoders (if not cached).

        Returns the local path to the safetensors file (~6 GB).
        """
        from huggingface_hub import hf_hub_download
        log.info("Downloading fp8 T5 encoder from %s ...", T5_FP8_REPO)
        return hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)

    def _build_runtime_config(self) -> str:
        """Load the template JSON, inject absolute ckpt paths, persist to temp.

        Returns:
            Path of the written JSON file. It is created with ``delete=False``
            and is not removed by this class.
        """
        with open(self.config_json_template, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        # Drop editorial comments before passing to LightX2V.
        cfg.pop("_comment", None)
        cfg["dit_quantized_ckpt"] = self._dit_ckpt
        # The template's fps wins; self.fps is only the fallback.
        # NOTE(review): generate_i2v always sizes clips with self.fps, so a
        # differing template fps changes playback speed — confirm intended.
        cfg.setdefault("fps", self.fps)

        # T5 fp8 quantization.
        if self._t5_fp8_ckpt:
            cfg["t5_quantized"] = True
            cfg["t5_quant_scheme"] = "fp8-sgl"
            cfg["t5_quantized_ckpt"] = self._t5_fp8_ckpt

        # ``with`` guarantees the handle is closed even if json.dump raises.
        with tempfile.NamedTemporaryFile(
            prefix="wan22_dit_", suffix=".json",
            mode="w", delete=False, encoding="utf-8",
        ) as tmp:
            json.dump(cfg, tmp, indent=2)
        log.info("Runtime LightX2V config: %s", tmp.name)
        return tmp.name

    @staticmethod
    def _build_args(
        *, model_cls: str, model_path: str, config_json: str
    ) -> argparse.Namespace:
        """Mirror every field from ``lightx2v.infer.main``'s argparse so
        ``set_config`` finds the attributes it expects. We only customize the
        model/task/path fields; everything else stays at the CLI defaults.
        """
        return argparse.Namespace(
            seed=42,
            model_cls=model_cls,
            task="i2v",
            support_tasks=[],
            model_path=model_path,
            sf_model_path=None,
            config_json=config_json,
            use_prompt_enhancer=False,
            prompt="",
            negative_prompt="",
            image_path="",
            last_frame_path="",
            audio_path="",
            image_strength="1.0",
            image_frame_idx="",
            src_ref_images=None,
            src_video=None,
            src_mask=None,
            src_pose_path=None,
            src_face_path=None,
            src_bg_path=None,
            src_mask_path=None,
            pose=None,
            action_path=None,
            action_ckpt=None,
            save_result_path=None,
            return_result_tensor=False,
            target_shape=[],
            target_video_length=81,
            aspect_ratio="",
            video_path=None,
            sr_ratio=2.0,
        )

    # --- LoRA --------------------------------------------------------------

    def load_loras(self, specs: list["LoRASpec"]) -> None:
        """Apply LoRAs to the dense Wan2.2-TI2V-5B pipeline.

        Dense has a single DIT (no MoE experts), so ``target`` must be
        ``"both"``. GGUF DIT weights don't expose a ``lora_down`` buffer,
        so ``switch_lora`` would crash — we use the dynamic-apply path that
        merges LoRAs during GGUF dequant. All specs are validated/resolved
        before the runner config is touched.

        Raises:
            ValueError: a spec's ``target`` is not ``"both"``.
        """
        if not specs:
            return

        resolved: list[tuple["LoRASpec", str]] = []
        for spec in specs:
            if spec.target != "both":
                raise ValueError(
                    f"Dense 5B Turbo has a single DIT; LoRA target must be "
                    f"'both' (got {spec.target!r})."
                )
            local_path = self._resolve_lora_path(spec.path)
            log.info("  LoRA %s → strength=%.2f (%s)",
                     spec.name or spec.path, spec.weight, local_path)
            resolved.append((spec, local_path))

        lora_cfgs = [
            {"path": local_path, "strength": spec.weight}
            for spec, local_path in resolved
        ]
        self._runner.set_config({
            "lora_configs": lora_cfgs,
            "lora_dynamic_apply": True,
        })
        self._applied_loras = list(specs)

    def unload_loras(self) -> None:
        """Remove all currently applied LoRAs (no-op when none are applied)."""
        if not self._applied_loras:
            return
        self._runner.set_config({
            "lora_configs": None,
            "lora_dynamic_apply": False,
        })
        self._applied_loras = []

    @staticmethod
    def _resolve_lora_path(path: str) -> str:
        """Resolve a LoRA path. Supports:
        - Absolute/relative local paths (returned as-is if the file exists)
        - ``repo_id:filename`` HuggingFace references

        Anything else is returned unchanged.
        """
        if os.path.isfile(path):
            return path
        # HF repo ids cannot start with "/", "." or "~", so such strings are
        # local paths even if they contain ":" (covers "./", "../", "~/").
        # A Windows drive prefix ("C:") must not be read as repo "C" either.
        looks_local = (
            path.startswith(("/", ".", "~"))
            or bool(os.path.splitdrive(path)[0])
        )
        if ":" in path and not looks_local:
            repo_id, filename = path.split(":", 1)
            from huggingface_hub import hf_hub_download
            return hf_hub_download(repo_id=repo_id, filename=filename)
        return path

    # --- Inference ---------------------------------------------------------

    @staticmethod
    def _frames_for_seconds(seconds: int, fps: int) -> int:
        """Convert a clip duration into Wan2.2's ``target_video_length``.

        ``target_video_length`` counts the conditioning frame, so N seconds
        maps to N*fps + 1. Wan's VAE needs lengths of the form 4k + 1, so the
        generated-frame count is rounded down to a multiple of 4. This is a
        no-op when ``fps`` is a multiple of 4, e.g. the default 16.

        Raises:
            ValueError: ``seconds`` or ``fps`` is not positive.
        """
        if seconds <= 0:
            raise ValueError(f"seconds must be positive (got {seconds!r}).")
        if fps <= 0:
            raise ValueError(f"fps must be positive (got {fps!r}).")
        generated = int(seconds * fps)
        return (generated // 4) * 4 + 1

    def generate_i2v(
        self,
        image_path: str,
        prompt: str,
        seconds: int,
        seed: int | None = None,
        negative_prompt: str = "",
    ) -> np.ndarray:
        """Run image-to-video inference and return decoded frames.

        Args:
            image_path: Conditioning image for the first frame.
            prompt: Text prompt.
            seconds: Clip duration, converted with ``self.fps`` into a
                4k + 1 frame count.
            seed: RNG seed; a random value in ``[0, 2**31 - 1]`` when ``None``.
            negative_prompt: Negative text prompt.

        Returns:
            ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB.

        Raises:
            ValueError: ``seconds`` (or ``self.fps``) is not positive.
        """
        if seed is None:
            seed = random.randint(0, 2**31 - 1)

        target_frames = self._frames_for_seconds(seconds, self.fps)

        from lightx2v.utils.input_info import update_input_info_from_dict  # type: ignore[import-not-found]

        # Reserve a unique output path; LightX2V writes the MP4 there.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
            out_path = tf.name
        try:
            log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d → %s",
                     prompt[:80], seconds, seed, out_path)
            update_input_info_from_dict(
                self._input_info_template,
                {
                    "seed": seed,
                    "prompt": prompt,
                    "negative_prompt": negative_prompt,
                    "image_path": image_path,
                    "save_result_path": out_path,
                    "target_video_length": target_frames,
                    "return_result_tensor": False,
                },
            )
            self._runner.run_pipeline(self._input_info_template)
            return _read_mp4_to_frames(out_path)
        finally:
            try:
                os.remove(out_path)
            except OSError:
                pass


# --- MP4 decoding helper ------------------------------------------------------

def _read_mp4_to_frames(path: str) -> np.ndarray:
    """Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``.

    Tries imageio's pyav plugin first. On any failure (missing package or
    decode error) it logs a warning and falls back to OpenCV.

    Raises:
        RuntimeError: the OpenCV fallback decoded zero frames.
    """
    try:
        import imageio.v3 as iio  # type: ignore[import-not-found]
        arr = np.asarray(iio.imread(path, plugin="pyav"))
        if arr.ndim == 3:
            arr = arr[None, ...]  # single frame → add a leading time axis
        # copy=False: no extra copy when the decoder already returns uint8.
        return arr.astype(np.uint8, copy=False)
    except Exception as e:  # pragma: no cover - fallback path
        log.warning("imageio decode failed (%s); falling back to cv2", e)
        import cv2  # type: ignore[import-not-found]
        cap = cv2.VideoCapture(path)
        frames: list[np.ndarray] = []
        try:
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release the capture even if read/cvtColor raises.
            cap.release()
        if not frames:
            raise RuntimeError(f"Failed to decode any frames from {path}")
        return np.stack(frames, axis=0).astype(np.uint8, copy=False)