working ok

This commit is contained in:
2026-04-16 10:00:37 -04:00
parent 9debc56137
commit 129df7d1fa
24 changed files with 674 additions and 539 deletions
+145 -187
View File
@@ -1,4 +1,4 @@
"""Wan2.2-Lightning image-to-video wrapper via LightX2V.
"""Wan2.2-TI2V-5B-Turbo (dense) image-to-video wrapper via LightX2V.
This wrapper targets LightX2V's actual Python entry points (verified against
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
@@ -22,15 +22,12 @@ Model weights are loaded once at construction and held resident across turns
so reflective mode doesn't re-pay the load cost each reply.
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
The bf16 DIT shards under high_noise_model/
and low_noise_model/ are SKIPPED via
ignore_patterns — we replace them with
quantised checkpoints from dit_repo.
- dit_repo (configurable) — quantised DIT checkpoints. Supported
formats:
* fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
* GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)
- Wan-AI/Wan2.2-TI2V-5B — T5 encoder, VAE, tokenizer/config only.
The bf16 DIT shards are SKIPPED via
ignore_patterns — replaced by the GGUF
checkpoint from dit_repo.
- dit_repo (configurable) — single dense GGUF DIT checkpoint, e.g.
hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
"""
from __future__ import annotations
@@ -49,27 +46,20 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
# --- fp8 distill filenames --------------------------------------------------
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
# --- GGUF filenames (QuantStack layout: HighNoise/<name>.gguf) ---------------
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"
# --- GGUF filename for the dense 5B Turbo repo ------------------------------
# hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF ships flat: Wan2_2-TI2V-5B-Turbo-{quant}.gguf
GGUF_TURBO_5B_FILE = "Wan2_2-TI2V-5B-Turbo-{quant}.gguf"
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the
# quantised files from dit_repo replace the DIT weights entirely. We must
# keep the config.json / index.json metadata under high_noise_model/ and
# low_noise_model/ (LightX2V's set_config reads architecture params like
# ``dim`` from them) and the tokenizer files under google/.
# The Wan-AI base repo ships bf16 DIT weight shards alongside the T5/VAE/
# tokenizer support files. We only need the latter — the GGUF from dit_repo
# replaces the DIT weights entirely. Keep config.json / tokenizer files.
BASE_REPO_IGNORE_PATTERNS = [
"high_noise_model/*.safetensors",
"low_noise_model/*.safetensors",
"*.pt",
"diffusion_pytorch_model*.safetensors",
"assets/*",
"examples/*",
"nohup.out",
@@ -77,6 +67,40 @@ BASE_REPO_IGNORE_PATTERNS = [
]
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
    """Recursively cast fp32 tensors reachable from ``obj`` to fp16, in place.

    The dense ``wan2.2`` DIT isn't a standard ``nn.Module`` — some fp32
    tensors (conv3d bias etc.) live outside ``pre_weight``/``post_weight``
    and are missed by the structured sweep. This generic attribute traversal
    catches them.

    Args:
        obj: Root object whose attributes are inspected via ``dir()``.
        visited: Set of ``id()`` values already traversed (cycle guard).
            Callers normally leave this as ``None``.
        depth: Current recursion depth; objects deeper than 6 are ignored.

    Returns:
        Number of tensor attributes that were replaced by an fp16 copy.
    """
    import torch

    if visited is None:
        visited = set()
    obj_id = id(obj)
    if obj_id in visited or depth > 6:
        return 0
    visited.add(obj_id)

    n = 0
    for attr_name in dir(obj):
        if attr_name.startswith("__"):
            continue
        try:
            val = getattr(obj, attr_name)
        except Exception:
            # Properties on foreign objects can raise anything; this sweep is
            # best-effort, so skip whatever cannot be read.
            continue

        if isinstance(val, torch.Tensor):
            # Tensors are leaves. They do carry a ``__dict__``, so without
            # this branch every non-fp32 tensor would be walked with dir(),
            # evaluating properties such as ``.T``/``.data`` and recursing
            # into throwaway temporaries.
            if val.dtype == torch.float32 and val.numel() > 0:
                try:
                    setattr(obj, attr_name, val.to(torch.float16))
                except Exception:
                    # Read-only property or type-checked slot: leave as is.
                    continue
                n += 1
            continue

        if hasattr(val, "__dict__") and not callable(val):
            n += _cast_all_fp32_tensors(val, visited, depth + 1)
    return n
def _patch_fp8_scaled_mm_for_blackwell() -> None:
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
@@ -132,13 +156,11 @@ def _patch_fp8_scaled_mm_for_blackwell() -> None:
class Wan22Pipeline:
"""Wrapper around LightX2V's Wan2.2 MoE distill runner.
"""Wrapper around LightX2V's dense Wan2.2-TI2V-5B-Turbo runner.
Supports two DIT quantisation formats:
* **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
``lightx2v/Wan2.2-Distill-Models``)
* **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)
The 5B Turbo repo ships a single dense DIT checkpoint (not MoE) as GGUF.
``dit_quant_scheme`` must be a GGUF variant (``gguf-Q8_0`` default,
``gguf-Q4_K_M`` for lower VRAM); no fp8 5B Turbo weights exist.
Constructor downloads (if needed) both HF repos, writes a runtime JSON
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
@@ -150,11 +172,11 @@ class Wan22Pipeline:
base_repo: str,
dit_repo: str,
config_json: str,
model_cls: str = "wan2.2_moe_distill",
model_cls: str = "wan2.2",
resolution: int = 480,
fps: int = 16,
dit_quant_scheme: str = "fp8-sgl",
t5_quantized: bool = False,
dit_quant_scheme: str = "gguf-Q8_0",
t5_quantized: bool = True,
):
self.base_repo = base_repo
self.dit_repo = dit_repo
@@ -167,10 +189,15 @@ class Wan22Pipeline:
self._applied_loras: list[LoRASpec] = []
self._is_gguf = dit_quant_scheme.startswith("gguf-")
if not self._is_gguf:
raise ValueError(
f"dit_quant_scheme must be a GGUF variant for dense 5B Turbo "
f"(got {dit_quant_scheme!r}); no fp8 5B Turbo weights exist."
)
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpt.
self._model_root = self._ensure_base_repo(base_repo)
self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
self._dit_ckpt = self._ensure_dit_checkpoint(
dit_repo, dit_quant_scheme,
)
self._t5_fp8_ckpt = (
@@ -306,52 +333,53 @@ class Wan22Pipeline:
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
def _patch_dit_fp32_weights_for_gguf(self) -> None:
"""Cast leftover fp32 DIT pre/post weights to fp16.
"""Cast leftover fp32 DIT weights to fp16 (dense model).
GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
of non-quantised weights (notably ``patch_embedding.pin_weight``) end
up loaded as fp32. That breaks the first conv in the DIT forward pass
(fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
runs uniformly in fp16.
GGUF dequantises the transformer blocks to fp16, but a handful of
non-quantised weights (notably ``patch_embedding.pin_weight``) end up
loaded as fp32. That breaks the first conv in the DIT forward pass
(fp16 input vs fp32 weight). Dense ``wan2.2`` exposes the model
directly at ``runner.model`` (no MoE wrapper). After the structured
pre/post weight sweep, we also run a recursive traversal to catch
fp32 conv3d biases etc. that live outside pre/post_weight.
"""
import torch
runner = self._runner
models = getattr(runner.model, "model", None)
if models is None:
return
if not isinstance(models, (list, tuple)):
models = [models]
n_struct = self._cast_fp32_dit_weights_in_model(runner.model)
n_extra = _cast_all_fp32_tensors(runner.model)
log.info(
"Cast %d (structured) + %d (recursive) fp32 DIT tensors to fp16 for GGUF pipeline.",
n_struct, n_extra,
)
@staticmethod
def _cast_fp32_dit_weights_in_model(m) -> int:
import torch
n_cast = 0
for m in models:
for weights_attr in ("pre_weight", "post_weight"):
w = getattr(m, weights_attr, None)
if w is None:
for weights_attr in ("pre_weight", "post_weight"):
w = getattr(m, weights_attr, None)
if w is None:
continue
for sub_name in dir(w):
if sub_name.startswith("_"):
continue
for sub_name in dir(w):
if sub_name.startswith("_"):
continue
try:
sub = getattr(w, sub_name)
except Exception:
continue
if sub is None:
continue
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
t = getattr(sub, t_name, None)
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
casted = t.to(torch.float16)
# Preserve pinned-memory status on pin_* tensors so
# move_attr_to_cuda's non-blocking H2D copy is safe.
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
try:
casted = casted.pin_memory()
except RuntimeError:
pass
setattr(sub, t_name, casted)
n_cast += 1
log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
try:
sub = getattr(w, sub_name)
except Exception:
continue
if sub is None:
continue
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
t = getattr(sub, t_name, None)
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
casted = t.to(torch.float16)
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
try:
casted = casted.pin_memory()
except RuntimeError:
pass
setattr(sub, t_name, casted)
n_cast += 1
return n_cast
# --- Weight provisioning -------------------------------------------------
@@ -374,46 +402,34 @@ class Wan22Pipeline:
)
@staticmethod
def _ensure_dit_checkpoints(
def _ensure_dit_checkpoint(
dit_repo: str,
dit_quant_scheme: str,
) -> tuple[str, str]:
"""Return (high_noise_path, low_noise_path) for the DIT pair.
Supports both fp8 safetensors and GGUF formats.
"""
) -> str:
"""Return the local path to the single dense GGUF DIT checkpoint."""
if not dit_repo:
raise ValueError("dit_repo must be a HF repo id or local directory.")
if not dit_quant_scheme.startswith("gguf-"):
raise ValueError(
f"Only GGUF quant schemes are supported for dense 5B Turbo "
f"(got {dit_quant_scheme!r})."
)
is_gguf = dit_quant_scheme.startswith("gguf-")
quant = dit_quant_scheme.replace("gguf-", "")
filename = GGUF_TURBO_5B_FILE.format(quant=quant)
if is_gguf:
# Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M"
quant = dit_quant_scheme.replace("gguf-", "")
high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
else:
high_file = FP8_HIGH_NOISE_FILE
low_file = FP8_LOW_NOISE_FILE
# Local directory?
if os.path.isdir(dit_repo):
high = os.path.join(dit_repo, high_file)
low = os.path.join(dit_repo, low_file)
if not (os.path.isfile(high) and os.path.isfile(low)):
path = os.path.join(dit_repo, filename)
if not os.path.isfile(path):
raise FileNotFoundError(
f"DIT checkpoints not found in {dit_repo}: expected "
f"{high_file} and {low_file}"
f"DIT checkpoint not found in {dit_repo}: expected {filename}"
)
return high, low
return path
# HuggingFace download.
from huggingface_hub import hf_hub_download
log.info("Downloading %s DIT checkpoints from %s ...",
log.info("Downloading %s DIT checkpoint from %s ...",
dit_quant_scheme, dit_repo)
high = hf_hub_download(repo_id=dit_repo, filename=high_file)
low = hf_hub_download(repo_id=dit_repo, filename=low_file)
return high, low
return hf_hub_download(repo_id=dit_repo, filename=filename)
@staticmethod
def _ensure_t5_fp8() -> str:
@@ -431,8 +447,7 @@ class Wan22Pipeline:
cfg = json.load(f)
# Drop editorial comments before passing to LightX2V.
cfg.pop("_comment", None)
cfg["high_noise_quantized_ckpt"] = self._dit_high
cfg["low_noise_quantized_ckpt"] = self._dit_low
cfg["dit_quantized_ckpt"] = self._dit_ckpt
cfg.setdefault("fps", self.fps)
# T5 fp8 quantization.
@@ -496,103 +511,46 @@ class Wan22Pipeline:
# --- LoRA --------------------------------------------------------------
def load_loras(self, specs: list["LoRASpec"]) -> None:
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
"""Apply LoRAs to the dense Wan2.2-TI2V-5B pipeline.
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
to route the LoRA to the correct expert.
With ``lazy_load`` the DIT models are ``None`` at this point, so
runtime ``switch_lora`` is impossible. Instead we inject
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
the LoRAs are applied when the models materialise on first inference.
Without ``lazy_load`` (models already resident) we call
``switch_lora`` with explicit high/low keyword args.
Dense has a single DIT (no MoE experts), so ``target`` must be
``"both"``. GGUF DIT weights don't expose a ``lora_down`` buffer,
so ``switch_lora`` would crash — we use the dynamic-apply path that
merges LoRAs during GGUF dequant.
"""
if not specs:
return
# Resolve every path up-front (may trigger HF download).
resolved: list[tuple["LoRASpec", str]] = []
for spec in specs:
if spec.target != "both":
raise ValueError(
f"Dense 5B Turbo has a single DIT; LoRA target must be "
f"'both' (got {spec.target!r})."
)
local_path = self._resolve_lora_path(spec.path)
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
spec.name or spec.path, spec.weight, spec.target,
local_path)
log.info(" LoRA %s → strength=%.2f (%s)",
spec.name or spec.path, spec.weight, local_path)
resolved.append((spec, local_path))
lazy = self._config.get("lazy_load", False)
if lazy:
# Build the lora_configs list that LightX2V's lazy-load path
# reads inside MultiDistillModelStruct.infer().
lora_cfgs = []
for spec, local_path in resolved:
# LightX2V expects name "high_noise_model" / "low_noise_model"
cfg_name = {
"high_noise": "high_noise_model",
"low_noise": "low_noise_model",
}.get(spec.target)
if cfg_name is None:
raise ValueError(
f"LoRA target must be 'high_noise' or 'low_noise', "
f"got {spec.target!r}")
lora_cfgs.append({
"name": cfg_name,
"path": local_path,
"strength": spec.weight,
})
self._runner.set_config({
"lora_configs": lora_cfgs,
"lora_dynamic_apply": True,
})
else:
# Models are loaded — use runtime hot-swap.
high_path = high_strength = None
low_path = low_strength = None
for spec, local_path in resolved:
if spec.target == "high_noise":
high_path, high_strength = local_path, spec.weight
elif spec.target == "low_noise":
low_path, low_strength = local_path, spec.weight
else:
raise ValueError(
f"LoRA target must be 'high_noise' or 'low_noise', "
f"got {spec.target!r}")
kwargs: dict = {}
if high_path is not None:
kwargs["high_lora_path"] = high_path
kwargs["high_lora_strength"] = high_strength
if low_path is not None:
kwargs["low_lora_path"] = low_path
kwargs["low_lora_strength"] = low_strength
ok = self._runner.switch_lora(**kwargs)
if not ok:
raise RuntimeError(
"runner.switch_lora returned False. Check that your "
"LightX2V build supports runtime LoRA updates for "
f"{self.model_cls}.")
lora_cfgs = [
{"path": local_path, "strength": spec.weight}
for spec, local_path in resolved
]
self._runner.set_config({
"lora_configs": lora_cfgs,
"lora_dynamic_apply": True,
})
self._applied_loras = list(specs)
def unload_loras(self) -> None:
"""Remove all currently applied LoRAs."""
if not self._applied_loras:
return
lazy = self._config.get("lazy_load", False)
if lazy:
self._runner.set_config({
"lora_configs": None,
"lora_dynamic_apply": False,
})
# If models were materialised, drop them so the next inference
# recreates them without LoRAs.
model_struct = getattr(self._runner, "model", None)
if model_struct is not None and hasattr(model_struct, "model"):
for i in range(len(model_struct.model)):
model_struct.model[i] = None
else:
self._runner.switch_lora("", 0.0)
self._runner.set_config({
"lora_configs": None,
"lora_dynamic_apply": False,
})
self._applied_loras = []
@staticmethod