working ok
This commit is contained in:
+145
-187
@@ -1,4 +1,4 @@
|
||||
"""Wan2.2-Lightning image-to-video wrapper via LightX2V.
|
||||
"""Wan2.2-TI2V-5B-Turbo (dense) image-to-video wrapper via LightX2V.
|
||||
|
||||
This wrapper targets LightX2V's actual Python entry points (verified against
|
||||
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
|
||||
@@ -22,15 +22,12 @@ Model weights are loaded once at construction and held resident across turns
|
||||
so reflective mode doesn't re-pay the load cost each reply.
|
||||
|
||||
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
|
||||
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards under high_noise_model/
|
||||
and low_noise_model/ are SKIPPED via
|
||||
ignore_patterns — we replace them with
|
||||
quantised checkpoints from dit_repo.
|
||||
- dit_repo (configurable) — quantised DIT checkpoints. Supported
|
||||
formats:
|
||||
* fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
|
||||
* GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)
|
||||
- Wan-AI/Wan2.2-TI2V-5B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards are SKIPPED via
|
||||
ignore_patterns — replaced by the GGUF
|
||||
checkpoint from dit_repo.
|
||||
- dit_repo (configurable) — single dense GGUF DIT checkpoint, e.g.
|
||||
hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -49,27 +46,20 @@ if TYPE_CHECKING:
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# --- fp8 distill filenames --------------------------------------------------
|
||||
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
|
||||
# --- GGUF filenames (QuantStack layout: HighNoise/<name>.gguf) ---------------
|
||||
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
|
||||
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"
|
||||
# --- GGUF filename for the dense 5B Turbo repo ------------------------------
|
||||
# hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF ships flat: Wan2_2-TI2V-5B-Turbo-{quant}.gguf
|
||||
GGUF_TURBO_5B_FILE = "Wan2_2-TI2V-5B-Turbo-{quant}.gguf"
|
||||
|
||||
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
|
||||
T5_FP8_REPO = "lightx2v/Encoders"
|
||||
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
|
||||
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
|
||||
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the
|
||||
# quantised files from dit_repo replace the DIT weights entirely. We must
|
||||
# keep the config.json / index.json metadata under high_noise_model/ and
|
||||
# low_noise_model/ (LightX2V's set_config reads architecture params like
|
||||
# ``dim`` from them) and the tokenizer files under google/.
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards alongside the T5/VAE/
|
||||
# tokenizer support files. We only need the latter — the GGUF from dit_repo
|
||||
# replaces the DIT weights entirely. Keep config.json / tokenizer files.
|
||||
BASE_REPO_IGNORE_PATTERNS = [
|
||||
"high_noise_model/*.safetensors",
|
||||
"low_noise_model/*.safetensors",
|
||||
"*.pt",
|
||||
"diffusion_pytorch_model*.safetensors",
|
||||
"assets/*",
|
||||
"examples/*",
|
||||
"nohup.out",
|
||||
@@ -77,6 +67,40 @@ BASE_REPO_IGNORE_PATTERNS = [
|
||||
]
|
||||
|
||||
|
||||
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
|
||||
"""Recursively find fp32 tensors reachable from ``obj`` and cast to fp16.
|
||||
|
||||
The dense ``wan2.2`` DIT isn't a standard ``nn.Module`` — some fp32
|
||||
tensors (conv3d bias etc.) live outside ``pre_weight``/``post_weight``
|
||||
and are missed by the structured sweep. This generic traversal catches
|
||||
them. Bounded depth + visited-set to avoid cycles.
|
||||
"""
|
||||
import torch
|
||||
if visited is None:
|
||||
visited = set()
|
||||
obj_id = id(obj)
|
||||
if obj_id in visited or depth > 6:
|
||||
return 0
|
||||
visited.add(obj_id)
|
||||
n = 0
|
||||
for attr_name in dir(obj):
|
||||
if attr_name.startswith("__"):
|
||||
continue
|
||||
try:
|
||||
val = getattr(obj, attr_name)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
|
||||
try:
|
||||
setattr(obj, attr_name, val.to(torch.float16))
|
||||
n += 1
|
||||
except Exception:
|
||||
pass
|
||||
elif hasattr(val, "__dict__") and not callable(val):
|
||||
n += _cast_all_fp32_tensors(val, visited, depth + 1)
|
||||
return n
|
||||
|
||||
|
||||
def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
|
||||
|
||||
@@ -132,13 +156,11 @@ def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||
|
||||
|
||||
class Wan22Pipeline:
|
||||
"""Wrapper around LightX2V's Wan2.2 MoE distill runner.
|
||||
"""Wrapper around LightX2V's dense Wan2.2-TI2V-5B-Turbo runner.
|
||||
|
||||
Supports two DIT quantisation formats:
|
||||
* **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
|
||||
``lightx2v/Wan2.2-Distill-Models``)
|
||||
* **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
|
||||
from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)
|
||||
The 5B Turbo repo ships a single dense DIT checkpoint (not MoE) as GGUF.
|
||||
``dit_quant_scheme`` must be a GGUF variant (``gguf-Q8_0`` default,
|
||||
``gguf-Q4_K_M`` for lower VRAM); no fp8 5B Turbo weights exist.
|
||||
|
||||
Constructor downloads (if needed) both HF repos, writes a runtime JSON
|
||||
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
|
||||
@@ -150,11 +172,11 @@ class Wan22Pipeline:
|
||||
base_repo: str,
|
||||
dit_repo: str,
|
||||
config_json: str,
|
||||
model_cls: str = "wan2.2_moe_distill",
|
||||
model_cls: str = "wan2.2",
|
||||
resolution: int = 480,
|
||||
fps: int = 16,
|
||||
dit_quant_scheme: str = "fp8-sgl",
|
||||
t5_quantized: bool = False,
|
||||
dit_quant_scheme: str = "gguf-Q8_0",
|
||||
t5_quantized: bool = True,
|
||||
):
|
||||
self.base_repo = base_repo
|
||||
self.dit_repo = dit_repo
|
||||
@@ -167,10 +189,15 @@ class Wan22Pipeline:
|
||||
self._applied_loras: list[LoRASpec] = []
|
||||
|
||||
self._is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||
if not self._is_gguf:
|
||||
raise ValueError(
|
||||
f"dit_quant_scheme must be a GGUF variant for dense 5B Turbo "
|
||||
f"(got {dit_quant_scheme!r}); no fp8 5B Turbo weights exist."
|
||||
)
|
||||
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpt.
|
||||
self._model_root = self._ensure_base_repo(base_repo)
|
||||
self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
|
||||
self._dit_ckpt = self._ensure_dit_checkpoint(
|
||||
dit_repo, dit_quant_scheme,
|
||||
)
|
||||
self._t5_fp8_ckpt = (
|
||||
@@ -306,52 +333,53 @@ class Wan22Pipeline:
|
||||
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
|
||||
|
||||
def _patch_dit_fp32_weights_for_gguf(self) -> None:
|
||||
"""Cast leftover fp32 DIT pre/post weights to fp16.
|
||||
"""Cast leftover fp32 DIT weights to fp16 (dense model).
|
||||
|
||||
GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
|
||||
of non-quantised weights (notably ``patch_embedding.pin_weight``) end
|
||||
up loaded as fp32. That breaks the first conv in the DIT forward pass
|
||||
(fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
|
||||
runs uniformly in fp16.
|
||||
GGUF dequantises the transformer blocks to fp16, but a handful of
|
||||
non-quantised weights (notably ``patch_embedding.pin_weight``) end up
|
||||
loaded as fp32. That breaks the first conv in the DIT forward pass
|
||||
(fp16 input vs fp32 weight). Dense ``wan2.2`` exposes the model
|
||||
directly at ``runner.model`` (no MoE wrapper). After the structured
|
||||
pre/post weight sweep, we also run a recursive traversal to catch
|
||||
fp32 conv3d biases etc. that live outside pre/post_weight.
|
||||
"""
|
||||
import torch
|
||||
|
||||
runner = self._runner
|
||||
models = getattr(runner.model, "model", None)
|
||||
if models is None:
|
||||
return
|
||||
if not isinstance(models, (list, tuple)):
|
||||
models = [models]
|
||||
n_struct = self._cast_fp32_dit_weights_in_model(runner.model)
|
||||
n_extra = _cast_all_fp32_tensors(runner.model)
|
||||
log.info(
|
||||
"Cast %d (structured) + %d (recursive) fp32 DIT tensors to fp16 for GGUF pipeline.",
|
||||
n_struct, n_extra,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _cast_fp32_dit_weights_in_model(m) -> int:
|
||||
import torch
|
||||
n_cast = 0
|
||||
for m in models:
|
||||
for weights_attr in ("pre_weight", "post_weight"):
|
||||
w = getattr(m, weights_attr, None)
|
||||
if w is None:
|
||||
for weights_attr in ("pre_weight", "post_weight"):
|
||||
w = getattr(m, weights_attr, None)
|
||||
if w is None:
|
||||
continue
|
||||
for sub_name in dir(w):
|
||||
if sub_name.startswith("_"):
|
||||
continue
|
||||
for sub_name in dir(w):
|
||||
if sub_name.startswith("_"):
|
||||
continue
|
||||
try:
|
||||
sub = getattr(w, sub_name)
|
||||
except Exception:
|
||||
continue
|
||||
if sub is None:
|
||||
continue
|
||||
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
|
||||
t = getattr(sub, t_name, None)
|
||||
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
|
||||
casted = t.to(torch.float16)
|
||||
# Preserve pinned-memory status on pin_* tensors so
|
||||
# move_attr_to_cuda's non-blocking H2D copy is safe.
|
||||
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
|
||||
try:
|
||||
casted = casted.pin_memory()
|
||||
except RuntimeError:
|
||||
pass
|
||||
setattr(sub, t_name, casted)
|
||||
n_cast += 1
|
||||
log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
|
||||
try:
|
||||
sub = getattr(w, sub_name)
|
||||
except Exception:
|
||||
continue
|
||||
if sub is None:
|
||||
continue
|
||||
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
|
||||
t = getattr(sub, t_name, None)
|
||||
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
|
||||
casted = t.to(torch.float16)
|
||||
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
|
||||
try:
|
||||
casted = casted.pin_memory()
|
||||
except RuntimeError:
|
||||
pass
|
||||
setattr(sub, t_name, casted)
|
||||
n_cast += 1
|
||||
return n_cast
|
||||
|
||||
# --- Weight provisioning -------------------------------------------------
|
||||
|
||||
@@ -374,46 +402,34 @@ class Wan22Pipeline:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_dit_checkpoints(
|
||||
def _ensure_dit_checkpoint(
|
||||
dit_repo: str,
|
||||
dit_quant_scheme: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Return (high_noise_path, low_noise_path) for the DIT pair.
|
||||
|
||||
Supports both fp8 safetensors and GGUF formats.
|
||||
"""
|
||||
) -> str:
|
||||
"""Return the local path to the single dense GGUF DIT checkpoint."""
|
||||
if not dit_repo:
|
||||
raise ValueError("dit_repo must be a HF repo id or local directory.")
|
||||
if not dit_quant_scheme.startswith("gguf-"):
|
||||
raise ValueError(
|
||||
f"Only GGUF quant schemes are supported for dense 5B Turbo "
|
||||
f"(got {dit_quant_scheme!r})."
|
||||
)
|
||||
|
||||
is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||
quant = dit_quant_scheme.replace("gguf-", "")
|
||||
filename = GGUF_TURBO_5B_FILE.format(quant=quant)
|
||||
|
||||
if is_gguf:
|
||||
# Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M"
|
||||
quant = dit_quant_scheme.replace("gguf-", "")
|
||||
high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
|
||||
low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
|
||||
else:
|
||||
high_file = FP8_HIGH_NOISE_FILE
|
||||
low_file = FP8_LOW_NOISE_FILE
|
||||
|
||||
# Local directory?
|
||||
if os.path.isdir(dit_repo):
|
||||
high = os.path.join(dit_repo, high_file)
|
||||
low = os.path.join(dit_repo, low_file)
|
||||
if not (os.path.isfile(high) and os.path.isfile(low)):
|
||||
path = os.path.join(dit_repo, filename)
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(
|
||||
f"DIT checkpoints not found in {dit_repo}: expected "
|
||||
f"{high_file} and {low_file}"
|
||||
f"DIT checkpoint not found in {dit_repo}: expected {filename}"
|
||||
)
|
||||
return high, low
|
||||
return path
|
||||
|
||||
# HuggingFace download.
|
||||
from huggingface_hub import hf_hub_download
|
||||
log.info("Downloading %s DIT checkpoints from %s ...",
|
||||
log.info("Downloading %s DIT checkpoint from %s ...",
|
||||
dit_quant_scheme, dit_repo)
|
||||
high = hf_hub_download(repo_id=dit_repo, filename=high_file)
|
||||
low = hf_hub_download(repo_id=dit_repo, filename=low_file)
|
||||
return high, low
|
||||
return hf_hub_download(repo_id=dit_repo, filename=filename)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_t5_fp8() -> str:
|
||||
@@ -431,8 +447,7 @@ class Wan22Pipeline:
|
||||
cfg = json.load(f)
|
||||
# Drop editorial comments before passing to LightX2V.
|
||||
cfg.pop("_comment", None)
|
||||
cfg["high_noise_quantized_ckpt"] = self._dit_high
|
||||
cfg["low_noise_quantized_ckpt"] = self._dit_low
|
||||
cfg["dit_quantized_ckpt"] = self._dit_ckpt
|
||||
cfg.setdefault("fps", self.fps)
|
||||
|
||||
# T5 fp8 quantization.
|
||||
@@ -496,103 +511,46 @@ class Wan22Pipeline:
|
||||
# --- LoRA --------------------------------------------------------------
|
||||
|
||||
def load_loras(self, specs: list["LoRASpec"]) -> None:
|
||||
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
|
||||
"""Apply LoRAs to the dense Wan2.2-TI2V-5B pipeline.
|
||||
|
||||
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
|
||||
to route the LoRA to the correct expert.
|
||||
|
||||
With ``lazy_load`` the DIT models are ``None`` at this point, so
|
||||
runtime ``switch_lora`` is impossible. Instead we inject
|
||||
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
|
||||
the LoRAs are applied when the models materialise on first inference.
|
||||
|
||||
Without ``lazy_load`` (models already resident) we call
|
||||
``switch_lora`` with explicit high/low keyword args.
|
||||
Dense has a single DIT (no MoE experts), so ``target`` must be
|
||||
``"both"``. GGUF DIT weights don't expose a ``lora_down`` buffer,
|
||||
so ``switch_lora`` would crash — we use the dynamic-apply path that
|
||||
merges LoRAs during GGUF dequant.
|
||||
"""
|
||||
if not specs:
|
||||
return
|
||||
|
||||
# Resolve every path up-front (may trigger HF download).
|
||||
resolved: list[tuple["LoRASpec", str]] = []
|
||||
for spec in specs:
|
||||
if spec.target != "both":
|
||||
raise ValueError(
|
||||
f"Dense 5B Turbo has a single DIT; LoRA target must be "
|
||||
f"'both' (got {spec.target!r})."
|
||||
)
|
||||
local_path = self._resolve_lora_path(spec.path)
|
||||
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
|
||||
spec.name or spec.path, spec.weight, spec.target,
|
||||
local_path)
|
||||
log.info(" LoRA %s → strength=%.2f (%s)",
|
||||
spec.name or spec.path, spec.weight, local_path)
|
||||
resolved.append((spec, local_path))
|
||||
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
# Build the lora_configs list that LightX2V's lazy-load path
|
||||
# reads inside MultiDistillModelStruct.infer().
|
||||
lora_cfgs = []
|
||||
for spec, local_path in resolved:
|
||||
# LightX2V expects name "high_noise_model" / "low_noise_model"
|
||||
cfg_name = {
|
||||
"high_noise": "high_noise_model",
|
||||
"low_noise": "low_noise_model",
|
||||
}.get(spec.target)
|
||||
if cfg_name is None:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
lora_cfgs.append({
|
||||
"name": cfg_name,
|
||||
"path": local_path,
|
||||
"strength": spec.weight,
|
||||
})
|
||||
self._runner.set_config({
|
||||
"lora_configs": lora_cfgs,
|
||||
"lora_dynamic_apply": True,
|
||||
})
|
||||
else:
|
||||
# Models are loaded — use runtime hot-swap.
|
||||
high_path = high_strength = None
|
||||
low_path = low_strength = None
|
||||
for spec, local_path in resolved:
|
||||
if spec.target == "high_noise":
|
||||
high_path, high_strength = local_path, spec.weight
|
||||
elif spec.target == "low_noise":
|
||||
low_path, low_strength = local_path, spec.weight
|
||||
else:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
|
||||
kwargs: dict = {}
|
||||
if high_path is not None:
|
||||
kwargs["high_lora_path"] = high_path
|
||||
kwargs["high_lora_strength"] = high_strength
|
||||
if low_path is not None:
|
||||
kwargs["low_lora_path"] = low_path
|
||||
kwargs["low_lora_strength"] = low_strength
|
||||
ok = self._runner.switch_lora(**kwargs)
|
||||
if not ok:
|
||||
raise RuntimeError(
|
||||
"runner.switch_lora returned False. Check that your "
|
||||
"LightX2V build supports runtime LoRA updates for "
|
||||
f"{self.model_cls}.")
|
||||
|
||||
lora_cfgs = [
|
||||
{"path": local_path, "strength": spec.weight}
|
||||
for spec, local_path in resolved
|
||||
]
|
||||
self._runner.set_config({
|
||||
"lora_configs": lora_cfgs,
|
||||
"lora_dynamic_apply": True,
|
||||
})
|
||||
self._applied_loras = list(specs)
|
||||
|
||||
def unload_loras(self) -> None:
|
||||
"""Remove all currently applied LoRAs."""
|
||||
if not self._applied_loras:
|
||||
return
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
self._runner.set_config({
|
||||
"lora_configs": None,
|
||||
"lora_dynamic_apply": False,
|
||||
})
|
||||
# If models were materialised, drop them so the next inference
|
||||
# recreates them without LoRAs.
|
||||
model_struct = getattr(self._runner, "model", None)
|
||||
if model_struct is not None and hasattr(model_struct, "model"):
|
||||
for i in range(len(model_struct.model)):
|
||||
model_struct.model[i] = None
|
||||
else:
|
||||
self._runner.switch_lora("", 0.0)
|
||||
self._runner.set_config({
|
||||
"lora_configs": None,
|
||||
"lora_dynamic_apply": False,
|
||||
})
|
||||
self._applied_loras = []
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user