working ok
@@ -0,0 +1,78 @@
# Agent guide: server/video_models/

Wrappers around third-party video models. These are the trickiest files in the repo: LightX2V's internals move quickly upstream, and the Blackwell (RTX 5090 / SM120) GPU path requires several non-obvious patches layered on top. Read this before editing.
## Scope

- [wan22.py](wan22.py): LightX2V Wan2.2-I2V A14B MoE pipeline. Supports fp8 safetensors and GGUF DIT checkpoints. Loaded once at startup and held resident; per-turn calls go through `generate_i2v` and `switch_lora`.
- [musetalk.py](musetalk.py): MuseTalk lip-sync over base frames + TTS audio.
- [muxer.py](muxer.py): thin ffmpeg wrappers (frames → MP4 loop, frames + audio → MP4).

Nothing here is imported unless `config.video.enabled` is true.
## LightX2V entry points (upstream API)

Use these symbols, not internal/private ones:

```python
from lightx2v.utils.set_config import set_config
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.infer import init_runner

config = set_config(args)                           # args is an argparse.Namespace
input_info = init_empty_input_info(args.task, args.support_tasks)
runner = init_runner(config)                        # loads all weights; call ONCE

update_input_info_from_dict(input_info, {...})      # per-turn inputs
runner.run_pipeline(input_info)                     # MP4 written to save_result_path
runner.switch_lora(lora_path, strength)             # hot-swap
```

Keep model load out of the per-turn path: `init_runner` is expensive.
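A minimal sketch of the load-once pattern described above. `ResidentRunner`, `_StubRunner`, and `_build_stub` are hypothetical names for illustration; the injected factory stands in for `lambda: init_runner(config)`, and the stub only counts calls so the example runs without LightX2V or a GPU.

```python
from typing import Any, Callable, Optional


class ResidentRunner:
    """Hypothetical holder that builds the expensive runner at most once."""

    def __init__(self, build_runner: Callable[[], Any]) -> None:
        # build_runner stands in for `lambda: init_runner(config)`.
        self._build_runner = build_runner
        self._runner: Optional[Any] = None

    def get(self) -> Any:
        # The first call pays the load cost; later calls reuse the instance.
        if self._runner is None:
            self._runner = self._build_runner()
        return self._runner


class _StubRunner:
    """Stand-in for the real runner; it only records per-turn calls."""

    def __init__(self) -> None:
        self.turns = 0

    def run_pipeline(self, input_info: dict) -> None:
        self.turns += 1


load_count = 0


def _build_stub() -> _StubRunner:
    """Factory that counts how many times the 'model' was loaded."""
    global load_count
    load_count += 1
    return _StubRunner()


resident = ResidentRunner(_build_stub)
for prompt in ("first turn", "second turn", "third turn"):
    resident.get().run_pipeline({"prompt": prompt})
```

Injecting the factory keeps the holder testable on machines without the model weights; in the real wrapper the same role is played by constructing the pipeline once at startup.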
## Blackwell (SM120) patches: do not remove without testing

The GGUF pipeline works on a 5090 only because of layered patches in `wan22.py` and tuning in the LightX2V JSON configs under [configs/lightx2v/](../../configs/lightx2v/). Each patch exists because a stock upstream path segfaults or silently miscomputes on SM120.

**Dtype plumbing (GGUF path):**

- Default `DTYPE` must be `BF16` at `init_runner()` time; the T5 offload buffers break if it is FP16 at init.
- Flip `BF16 → FP16` *after* `init_runner()` returns.
- Wrap the T5 encoder so it runs under BF16 internally, then cast its outputs `bf16 → fp16` before handing them to the DIT. See `_patch_t5_dtype_for_gguf`.
- Cast **both** VAE layers: the inner `.model` via `.to(fp16)` **and** the outer `WanVAE` wrapper's `mean` / `inv_std` / `scale` tensors. Missing the wrapper tensors upcasts the latent during decode's `z/inv_std + mean`.
- DIT `pre_weight.patch_embedding.pin_weight` loads as fp32 (only `pin_bias` is fp16). Cast it **and** re-pin it via `.pin_memory()`; skipping the re-pin segfaults during the `to_cuda` H2D copy.
- `sgl_kernel`'s fp8 scaled matmul is patched to `torch._scaled_mm` in `_patch_fp8_scaled_mm_for_blackwell`.
**LightX2V JSON config (see `wan22_i2v_gguf_distill.json`):**

- `modulate_type: "torch"`: the Triton `fuse_scale_shift_kernel` segfaults in `ast_to_ttir` on Triton 3.4 + SM120.
- `rope_type: "torch"`: flashinfer isn't installed.
- `self_attn_1_type` / `cross_attn_*_type`: `"torch_sdpa"`, because flash_attn3 is unavailable and `sageattention==1.0.6` from PyPI segfaults on Blackwell (newer versions require a source build).

If you add a new quant scheme or a new `model_cls`, create its own JSON under `configs/lightx2v/` mirroring these choices, and exercise it end-to-end via a new `tests/component/test_NN_*.py` before wiring it into the default config.
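When mirroring these choices into a new JSON, a small check can catch a missed key before a test run segfaults. The sketch below is an assumption-laden illustration: `blackwell_config_problems` is a hypothetical helper, `cross_attn_1_type` is an invented key that merely fits the `cross_attn_*_type` pattern, and the "bad" values are placeholders rather than real LightX2V options.

```python
from typing import Any, Dict, List


def blackwell_config_problems(cfg: Dict[str, Any]) -> List[str]:
    """Return human-readable problems; an empty list means the config matches."""
    problems: List[str] = []
    for key in ("modulate_type", "rope_type"):
        if cfg.get(key) != "torch":
            problems.append(f"{key} should be 'torch', got {cfg.get(key)!r}")
    # Attention keys: self_attn_1_type plus every cross_attn_*_type present.
    attn_keys = [
        k for k in cfg
        if k == "self_attn_1_type"
        or (k.startswith("cross_attn_") and k.endswith("_type"))
    ]
    if "self_attn_1_type" not in attn_keys:
        problems.append("self_attn_1_type is missing")
    for key in sorted(attn_keys):
        if cfg[key] != "torch_sdpa":
            problems.append(f"{key} should be 'torch_sdpa', got {cfg[key]!r}")
    return problems


good = {
    "modulate_type": "torch",
    "rope_type": "torch",
    "self_attn_1_type": "torch_sdpa",
    "cross_attn_1_type": "torch_sdpa",  # hypothetical key name
}
# Placeholder values standing in for any non-torch backend.
bad = dict(good, modulate_type="some_triton_backend", self_attn_1_type="some_other_attn")
problems = blackwell_config_problems(bad)
```

A check like this does not replace the end-to-end component test; it only flags configs that drift from the settings listed above.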
## HF download layout

`Wan-AI/Wan2.2-I2V-A14B` ships ~28 GB of bf16 DIT shards, which we replace with the quantised checkpoints from `dit_repo`. `BASE_REPO_IGNORE_PATTERNS` in [wan22.py](wan22.py) excludes those shards but **keeps** `high_noise_model/*.json` and `low_noise_model/*.json`, because `set_config` parses architecture params (`dim`, etc.) from those files. Don't broaden the ignore pattern without checking.
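To see why broadening the ignore pattern is risky, the following sketch filters a made-up file listing with Python's `fnmatch`. It assumes the download filter uses glob-style matching comparable to `fnmatch`; the file names and the `kept_files` helper are hypothetical, and only the patterns themselves come from `wan22.py`.

```python
from fnmatch import fnmatch
from typing import Iterable, List

# Subset of the patterns described for BASE_REPO_IGNORE_PATTERNS.
IGNORE_PATTERNS = [
    "high_noise_model/*.safetensors",
    "low_noise_model/*.safetensors",
    "*.pt",
]


def kept_files(paths: Iterable[str], ignore: Iterable[str] = IGNORE_PATTERNS) -> List[str]:
    """Return the paths that no ignore pattern matches."""
    patterns = list(ignore)
    return [p for p in paths if not any(fnmatch(p, pat) for pat in patterns)]


# Hypothetical repo listing, for illustration only.
repo_files = [
    "high_noise_model/config.json",
    "high_noise_model/shard-00001.safetensors",
    "low_noise_model/config.json",
    "low_noise_model/shard-00001.safetensors",
]

survivors = kept_files(repo_files)
# A broader pattern also drops the JSON metadata that set_config needs.
too_broad = kept_files(repo_files, ["high_noise_model/*"])
```

With the narrow patterns both `config.json` files survive; with `high_noise_model/*` the high-noise config disappears along with the shards.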
Supported quant schemes live in `wan22_dit_quant_scheme`:

- `fp8-sgl`: `lightx2v/Wan2.2-Distill-Models`, two `.safetensors` files
- `gguf-Q4_K_M`, `gguf-Q8_0`, …: `QuantStack/Wan2.2-I2V-A14B-GGUF`, layout `HighNoise/…` and `LowNoise/…`

Filenames are templated at the top of `wan22.py`; update those if the upstream repos rename files.
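As an illustration of how a scheme string maps to file names, here is a small sketch using the two QuantStack-layout templates that appear in `wan22.py`. The `gguf_filenames` helper is a hypothetical name, not the wrapper's actual function.

```python
from typing import Tuple

# Templates as they appear at the top of wan22.py (QuantStack layout).
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"

_PREFIX = "gguf-"


def gguf_filenames(dit_quant_scheme: str) -> Tuple[str, str]:
    """Map e.g. 'gguf-Q4_K_M' to the (high_noise, low_noise) file names."""
    if not dit_quant_scheme.startswith(_PREFIX):
        raise ValueError(f"not a GGUF scheme: {dit_quant_scheme!r}")
    # Strip only the leading prefix, leaving the quant level untouched.
    quant = dit_quant_scheme[len(_PREFIX):]
    return (
        GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant),
        GGUF_LOW_NOISE_TEMPLATE.format(quant=quant),
    )


high_file, low_file = gguf_filenames("gguf-Q4_K_M")
```

If an upstream repo renames its files, only the two template constants should need to change.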
## Testing

- `tests/component/test_02_wan22_loras.py`: full pipeline load + LoRA apply
- `tests/component/test_09_gguf_generate.py`: GGUF end-to-end I2V
- `tests/component/test_10_t5_encode.py`: T5 encoder dtype path
- `tests/component/test_11_image_encode.py`: image → VAE latent
- `tests/component/test_12_dit_single_step.py`: one DIT step per expert
- `tests/component/test_13_vae_decode.py`: VAE decode → RGB

When diagnosing a Blackwell regression, run tests 10 → 11 → 12 → 13 in that order; the failure localises to the first failing stage.
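The ordered bisection above could be scripted roughly as follows. This is a sketch, not part of the repo: `run_pytest` and `first_failing_stage` are hypothetical helpers, and running each file in a fresh interpreter is an assumption made so that a segfault in one stage shows up as a non-zero exit code instead of killing the driver.

```python
import subprocess
import sys
from typing import Callable, Optional, Sequence

# Stage order taken from the test list above.
STAGES = [
    "tests/component/test_10_t5_encode.py",
    "tests/component/test_11_image_encode.py",
    "tests/component/test_12_dit_single_step.py",
    "tests/component/test_13_vae_decode.py",
]


def run_pytest(path: str) -> bool:
    """Run one test file in a separate process; True means it passed."""
    result = subprocess.run([sys.executable, "-m", "pytest", "-x", path])
    return result.returncode == 0


def first_failing_stage(
    stages: Sequence[str],
    run: Callable[[str], bool] = run_pytest,
) -> Optional[str]:
    """Return the first stage whose run fails, or None if all pass."""
    for stage in stages:
        if not run(stage):
            return stage
    return None


# Demonstration with a stub runner that pretends the DIT step regressed.
culprit = first_failing_stage(STAGES, run=lambda p: "test_12" not in p)
```

Calling `first_failing_stage(STAGES)` with the default runner would invoke pytest for real; the stub is used here so the example runs anywhere.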
## LoRAs

`switch_lora(path, strength)` applies a LoRA; `switch_lora("", 0.0)` removes it. `load_loras`/`unload_loras` in this wrapper iterate over the `LoRASpec`s from config and call `switch_lora` per `target` sub-model (`high_noise`, `low_noise`, or `both`). A wrong target produces silently wrong output.
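Because a bad target fails silently, it can help to validate targets before any call reaches the runner. The sketch below assumes a simplified `LoRASpec` with only `path`, `weight`, and `target` fields (the real class lives in the repo's config code and may differ), and `expand_targets` is a hypothetical helper.

```python
from dataclasses import dataclass
from typing import Iterable, List, Tuple

VALID_TARGETS = ("high_noise", "low_noise", "both")


@dataclass
class LoRASpec:
    """Simplified, hypothetical stand-in for the repo's LoRASpec."""
    path: str
    weight: float
    target: str


def expand_targets(specs: Iterable[LoRASpec]) -> List[Tuple[str, str, float]]:
    """Return (sub_model, path, strength) triples, rejecting unknown targets."""
    plan: List[Tuple[str, str, float]] = []
    for spec in specs:
        if spec.target not in VALID_TARGETS:
            raise ValueError(
                f"LoRA target must be one of {VALID_TARGETS}, got {spec.target!r}"
            )
        # "both" fans out to one entry per sub-model.
        if spec.target == "both":
            sub_models = ("high_noise", "low_noise")
        else:
            sub_models = (spec.target,)
        plan.extend((sub, spec.path, spec.weight) for sub in sub_models)
    return plan


plan = expand_targets([
    LoRASpec(path="style.safetensors", weight=0.8, target="both"),
    LoRASpec(path="detail.safetensors", weight=0.5, target="low_noise"),
])
```

This only catches misspelled targets. A valid but semantically wrong target (for example a high-noise LoRA routed to `low_noise`) still needs a visual check of the output.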
@@ -32,8 +32,13 @@ class MuseTalkEngine:
|
||||
def _load_impl(model_path: str):
|
||||
"""Load the MuseTalk inference implementation.
|
||||
|
||||
If none of the known entry points work the error message points at
|
||||
this file so you know where to fix it.
|
||||
Upstream MuseTalk has no library-style entry point — it's a bundle
|
||||
of training/inference CLI scripts. The bhetherman/MuseTalk fork at
|
||||
``third_party/MuseTalk`` adds package metadata but the low-level
|
||||
API is still the raw ``musetalk.utils.*`` and ``musetalk.models.*``
|
||||
modules. We import them here to verify the install succeeded; the
|
||||
actual pipeline (VAE, UNet, Whisper, face detection, blending)
|
||||
is wired up inside ``MuseTalkEngine.lip_sync``.
|
||||
"""
|
||||
resolved = model_path
|
||||
if not os.path.isdir(model_path) and "/" in model_path:
|
||||
@@ -43,28 +48,19 @@ class MuseTalkEngine:
|
||||
except Exception as e: # pragma: no cover
|
||||
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
|
||||
|
||||
# Try upstream MuseTalk repo layout.
|
||||
try:
|
||||
from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found]
|
||||
return MuseTalkInference(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found]
|
||||
return MuseTalkInfer(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from musetalk import Inference # type: ignore[import-not-found]
|
||||
return Inference(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
from musetalk.utils.utils import load_all_model # type: ignore[import-not-found] # noqa: F401
|
||||
from musetalk.utils.audio_processor import AudioProcessor # type: ignore[import-not-found] # noqa: F401
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MuseTalk Python package is not importable. "
|
||||
"Check that third_party/MuseTalk was installed via "
|
||||
"`pip install /opt/MuseTalk` in the Dockerfile."
|
||||
) from e
|
||||
|
||||
raise RuntimeError(
|
||||
"MuseTalk is installed but no known Python entry point was found. "
|
||||
"Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
|
||||
"to match the installed MuseTalk version."
|
||||
)
|
||||
# Return the resolved weight path; lip_sync loads models lazily on
|
||||
# first call so import-time failures don't block voice-only startup.
|
||||
return {"model_path": resolved, "loaded": False}
|
||||
|
||||
# --- Inference ---------------------------------------------------------
|
||||
|
||||
@@ -98,31 +94,22 @@ class MuseTalkEngine:
|
||||
if target_t > 0 and len(frames) != target_t:
|
||||
frames = _fit_frames_to_length(frames, target_t)
|
||||
|
||||
# The real MuseTalk call signature varies. Most common is a method
|
||||
# like ``run(frames, audio, sr, fps)`` or ``infer(...)``.
|
||||
for method_name in ("run", "infer", "lip_sync", "__call__"):
|
||||
method = getattr(self._infer, method_name, None)
|
||||
if method is None:
|
||||
continue
|
||||
try:
|
||||
result = method(
|
||||
frames=frames,
|
||||
audio=audio,
|
||||
sample_rate=sample_rate,
|
||||
fps=fps,
|
||||
)
|
||||
return _ensure_uint8_rgb(result)
|
||||
except TypeError:
|
||||
# Try positional
|
||||
try:
|
||||
result = method(frames, audio, sample_rate, fps)
|
||||
return _ensure_uint8_rgb(result)
|
||||
except TypeError:
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
"MuseTalk wrapper could not find a working inference method. "
|
||||
"Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
|
||||
# MuseTalk's real inference path (see third_party/MuseTalk/scripts/
|
||||
# realtime_inference.py::Avatar.inference) needs:
|
||||
# - mmpose + mmcv + mmengine (dwpose keypoint detection)
|
||||
# - face_alignment (bbox)
|
||||
# - MuseTalk UNet + VAE weights (TMElyralab/MuseTalk HF repo)
|
||||
# - Whisper encoder (openai/whisper-tiny)
|
||||
# - face_parsing weights
|
||||
# Plus its preprocessing module has import-time side effects that
|
||||
# load dwpose weights from a CWD-relative path. Turn the full
|
||||
# pipeline on by extending this method once those deps are
|
||||
# installed and weights are resolved — until then, callers should
|
||||
# keep ``config.video.musetalk.enabled: false`` and VideoEngine
|
||||
# will skip the lip-sync pass.
|
||||
raise NotImplementedError(
|
||||
"MuseTalk lip-sync pipeline is not wired up yet. "
|
||||
"Set config.video.musetalk.enabled=false to bypass."
|
||||
)
|
||||
|
||||
|
||||
|
||||
+145
-187
@@ -1,4 +1,4 @@
|
||||
"""Wan2.2-Lightning image-to-video wrapper via LightX2V.
|
||||
"""Wan2.2-TI2V-5B-Turbo (dense) image-to-video wrapper via LightX2V.
|
||||
|
||||
This wrapper targets LightX2V's actual Python entry points (verified against
|
||||
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
|
||||
@@ -22,15 +22,12 @@ Model weights are loaded once at construction and held resident across turns
|
||||
so reflective mode doesn't re-pay the load cost each reply.
|
||||
|
||||
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
|
||||
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards under high_noise_model/
|
||||
and low_noise_model/ are SKIPPED via
|
||||
ignore_patterns — we replace them with
|
||||
quantised checkpoints from dit_repo.
|
||||
- dit_repo (configurable) — quantised DIT checkpoints. Supported
|
||||
formats:
|
||||
* fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
|
||||
* GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)
|
||||
- Wan-AI/Wan2.2-TI2V-5B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards are SKIPPED via
|
||||
ignore_patterns — replaced by the GGUF
|
||||
checkpoint from dit_repo.
|
||||
- dit_repo (configurable) — single dense GGUF DIT checkpoint, e.g.
|
||||
hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -49,27 +46,20 @@ if TYPE_CHECKING:
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# --- fp8 distill filenames --------------------------------------------------
|
||||
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
|
||||
# --- GGUF filenames (QuantStack layout: HighNoise/<name>.gguf) ---------------
|
||||
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
|
||||
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"
|
||||
# --- GGUF filename for the dense 5B Turbo repo ------------------------------
|
||||
# hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF ships flat: Wan2_2-TI2V-5B-Turbo-{quant}.gguf
|
||||
GGUF_TURBO_5B_FILE = "Wan2_2-TI2V-5B-Turbo-{quant}.gguf"
|
||||
|
||||
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
|
||||
T5_FP8_REPO = "lightx2v/Encoders"
|
||||
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
|
||||
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
|
||||
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the
|
||||
# quantised files from dit_repo replace the DIT weights entirely. We must
|
||||
# keep the config.json / index.json metadata under high_noise_model/ and
|
||||
# low_noise_model/ (LightX2V's set_config reads architecture params like
|
||||
# ``dim`` from them) and the tokenizer files under google/.
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards alongside the T5/VAE/
|
||||
# tokenizer support files. We only need the latter — the GGUF from dit_repo
|
||||
# replaces the DIT weights entirely. Keep config.json / tokenizer files.
|
||||
BASE_REPO_IGNORE_PATTERNS = [
|
||||
"high_noise_model/*.safetensors",
|
||||
"low_noise_model/*.safetensors",
|
||||
"*.pt",
|
||||
"diffusion_pytorch_model*.safetensors",
|
||||
"assets/*",
|
||||
"examples/*",
|
||||
"nohup.out",
|
||||
@@ -77,6 +67,40 @@ BASE_REPO_IGNORE_PATTERNS = [
|
||||
]
|
||||
|
||||
|
||||
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
|
||||
"""Recursively find fp32 tensors reachable from ``obj`` and cast to fp16.
|
||||
|
||||
The dense ``wan2.2`` DIT isn't a standard ``nn.Module`` — some fp32
|
||||
tensors (conv3d bias etc.) live outside ``pre_weight``/``post_weight``
|
||||
and are missed by the structured sweep. This generic traversal catches
|
||||
them. Bounded depth + visited-set to avoid cycles.
|
||||
"""
|
||||
import torch
|
||||
if visited is None:
|
||||
visited = set()
|
||||
obj_id = id(obj)
|
||||
if obj_id in visited or depth > 6:
|
||||
return 0
|
||||
visited.add(obj_id)
|
||||
n = 0
|
||||
for attr_name in dir(obj):
|
||||
if attr_name.startswith("__"):
|
||||
continue
|
||||
try:
|
||||
val = getattr(obj, attr_name)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
|
||||
try:
|
||||
setattr(obj, attr_name, val.to(torch.float16))
|
||||
n += 1
|
||||
except Exception:
|
||||
pass
|
||||
elif hasattr(val, "__dict__") and not callable(val):
|
||||
n += _cast_all_fp32_tensors(val, visited, depth + 1)
|
||||
return n
|
||||
|
||||
|
||||
def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
|
||||
|
||||
@@ -132,13 +156,11 @@ def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||
|
||||
|
||||
class Wan22Pipeline:
|
||||
"""Wrapper around LightX2V's Wan2.2 MoE distill runner.
|
||||
"""Wrapper around LightX2V's dense Wan2.2-TI2V-5B-Turbo runner.
|
||||
|
||||
Supports two DIT quantisation formats:
|
||||
* **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
|
||||
``lightx2v/Wan2.2-Distill-Models``)
|
||||
* **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
|
||||
from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)
|
||||
The 5B Turbo repo ships a single dense DIT checkpoint (not MoE) as GGUF.
|
||||
``dit_quant_scheme`` must be a GGUF variant (``gguf-Q8_0`` default,
|
||||
``gguf-Q4_K_M`` for lower VRAM); no fp8 5B Turbo weights exist.
|
||||
|
||||
Constructor downloads (if needed) both HF repos, writes a runtime JSON
|
||||
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
|
||||
@@ -150,11 +172,11 @@ class Wan22Pipeline:
|
||||
base_repo: str,
|
||||
dit_repo: str,
|
||||
config_json: str,
|
||||
model_cls: str = "wan2.2_moe_distill",
|
||||
model_cls: str = "wan2.2",
|
||||
resolution: int = 480,
|
||||
fps: int = 16,
|
||||
dit_quant_scheme: str = "fp8-sgl",
|
||||
t5_quantized: bool = False,
|
||||
dit_quant_scheme: str = "gguf-Q8_0",
|
||||
t5_quantized: bool = True,
|
||||
):
|
||||
self.base_repo = base_repo
|
||||
self.dit_repo = dit_repo
|
||||
@@ -167,10 +189,15 @@ class Wan22Pipeline:
|
||||
self._applied_loras: list[LoRASpec] = []
|
||||
|
||||
self._is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||
if not self._is_gguf:
|
||||
raise ValueError(
|
||||
f"dit_quant_scheme must be a GGUF variant for dense 5B Turbo "
|
||||
f"(got {dit_quant_scheme!r}); no fp8 5B Turbo weights exist."
|
||||
)
|
||||
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpt.
|
||||
self._model_root = self._ensure_base_repo(base_repo)
|
||||
self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
|
||||
self._dit_ckpt = self._ensure_dit_checkpoint(
|
||||
dit_repo, dit_quant_scheme,
|
||||
)
|
||||
self._t5_fp8_ckpt = (
|
||||
@@ -306,52 +333,53 @@ class Wan22Pipeline:
|
||||
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
|
||||
|
||||
def _patch_dit_fp32_weights_for_gguf(self) -> None:
|
||||
"""Cast leftover fp32 DIT pre/post weights to fp16.
|
||||
"""Cast leftover fp32 DIT weights to fp16 (dense model).
|
||||
|
||||
GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
|
||||
of non-quantised weights (notably ``patch_embedding.pin_weight``) end
|
||||
up loaded as fp32. That breaks the first conv in the DIT forward pass
|
||||
(fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
|
||||
runs uniformly in fp16.
|
||||
GGUF dequantises the transformer blocks to fp16, but a handful of
|
||||
non-quantised weights (notably ``patch_embedding.pin_weight``) end up
|
||||
loaded as fp32. That breaks the first conv in the DIT forward pass
|
||||
(fp16 input vs fp32 weight). Dense ``wan2.2`` exposes the model
|
||||
directly at ``runner.model`` (no MoE wrapper). After the structured
|
||||
pre/post weight sweep, we also run a recursive traversal to catch
|
||||
fp32 conv3d biases etc. that live outside pre/post_weight.
|
||||
"""
|
||||
import torch
|
||||
|
||||
runner = self._runner
|
||||
models = getattr(runner.model, "model", None)
|
||||
if models is None:
|
||||
return
|
||||
if not isinstance(models, (list, tuple)):
|
||||
models = [models]
|
||||
n_struct = self._cast_fp32_dit_weights_in_model(runner.model)
|
||||
n_extra = _cast_all_fp32_tensors(runner.model)
|
||||
log.info(
|
||||
"Cast %d (structured) + %d (recursive) fp32 DIT tensors to fp16 for GGUF pipeline.",
|
||||
n_struct, n_extra,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _cast_fp32_dit_weights_in_model(m) -> int:
|
||||
import torch
|
||||
n_cast = 0
|
||||
for m in models:
|
||||
for weights_attr in ("pre_weight", "post_weight"):
|
||||
w = getattr(m, weights_attr, None)
|
||||
if w is None:
|
||||
for weights_attr in ("pre_weight", "post_weight"):
|
||||
w = getattr(m, weights_attr, None)
|
||||
if w is None:
|
||||
continue
|
||||
for sub_name in dir(w):
|
||||
if sub_name.startswith("_"):
|
||||
continue
|
||||
for sub_name in dir(w):
|
||||
if sub_name.startswith("_"):
|
||||
continue
|
||||
try:
|
||||
sub = getattr(w, sub_name)
|
||||
except Exception:
|
||||
continue
|
||||
if sub is None:
|
||||
continue
|
||||
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
|
||||
t = getattr(sub, t_name, None)
|
||||
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
|
||||
casted = t.to(torch.float16)
|
||||
# Preserve pinned-memory status on pin_* tensors so
|
||||
# move_attr_to_cuda's non-blocking H2D copy is safe.
|
||||
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
|
||||
try:
|
||||
casted = casted.pin_memory()
|
||||
except RuntimeError:
|
||||
pass
|
||||
setattr(sub, t_name, casted)
|
||||
n_cast += 1
|
||||
log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
|
||||
try:
|
||||
sub = getattr(w, sub_name)
|
||||
except Exception:
|
||||
continue
|
||||
if sub is None:
|
||||
continue
|
||||
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
|
||||
t = getattr(sub, t_name, None)
|
||||
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
|
||||
casted = t.to(torch.float16)
|
||||
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
|
||||
try:
|
||||
casted = casted.pin_memory()
|
||||
except RuntimeError:
|
||||
pass
|
||||
setattr(sub, t_name, casted)
|
||||
n_cast += 1
|
||||
return n_cast
|
||||
|
||||
# --- Weight provisioning -------------------------------------------------
|
||||
|
||||
@@ -374,46 +402,34 @@ class Wan22Pipeline:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_dit_checkpoints(
|
||||
def _ensure_dit_checkpoint(
|
||||
dit_repo: str,
|
||||
dit_quant_scheme: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Return (high_noise_path, low_noise_path) for the DIT pair.
|
||||
|
||||
Supports both fp8 safetensors and GGUF formats.
|
||||
"""
|
||||
) -> str:
|
||||
"""Return the local path to the single dense GGUF DIT checkpoint."""
|
||||
if not dit_repo:
|
||||
raise ValueError("dit_repo must be a HF repo id or local directory.")
|
||||
if not dit_quant_scheme.startswith("gguf-"):
|
||||
raise ValueError(
|
||||
f"Only GGUF quant schemes are supported for dense 5B Turbo "
|
||||
f"(got {dit_quant_scheme!r})."
|
||||
)
|
||||
|
||||
is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||
quant = dit_quant_scheme.replace("gguf-", "")
|
||||
filename = GGUF_TURBO_5B_FILE.format(quant=quant)
|
||||
|
||||
if is_gguf:
|
||||
# Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M"
|
||||
quant = dit_quant_scheme.replace("gguf-", "")
|
||||
high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
|
||||
low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
|
||||
else:
|
||||
high_file = FP8_HIGH_NOISE_FILE
|
||||
low_file = FP8_LOW_NOISE_FILE
|
||||
|
||||
# Local directory?
|
||||
if os.path.isdir(dit_repo):
|
||||
high = os.path.join(dit_repo, high_file)
|
||||
low = os.path.join(dit_repo, low_file)
|
||||
if not (os.path.isfile(high) and os.path.isfile(low)):
|
||||
path = os.path.join(dit_repo, filename)
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(
|
||||
f"DIT checkpoints not found in {dit_repo}: expected "
|
||||
f"{high_file} and {low_file}"
|
||||
f"DIT checkpoint not found in {dit_repo}: expected {filename}"
|
||||
)
|
||||
return high, low
|
||||
return path
|
||||
|
||||
# HuggingFace download.
|
||||
from huggingface_hub import hf_hub_download
|
||||
log.info("Downloading %s DIT checkpoints from %s ...",
|
||||
log.info("Downloading %s DIT checkpoint from %s ...",
|
||||
dit_quant_scheme, dit_repo)
|
||||
high = hf_hub_download(repo_id=dit_repo, filename=high_file)
|
||||
low = hf_hub_download(repo_id=dit_repo, filename=low_file)
|
||||
return high, low
|
||||
return hf_hub_download(repo_id=dit_repo, filename=filename)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_t5_fp8() -> str:
|
||||
@@ -431,8 +447,7 @@ class Wan22Pipeline:
|
||||
cfg = json.load(f)
|
||||
# Drop editorial comments before passing to LightX2V.
|
||||
cfg.pop("_comment", None)
|
||||
cfg["high_noise_quantized_ckpt"] = self._dit_high
|
||||
cfg["low_noise_quantized_ckpt"] = self._dit_low
|
||||
cfg["dit_quantized_ckpt"] = self._dit_ckpt
|
||||
cfg.setdefault("fps", self.fps)
|
||||
|
||||
# T5 fp8 quantization.
|
||||
@@ -496,103 +511,46 @@ class Wan22Pipeline:
|
||||
# --- LoRA --------------------------------------------------------------
|
||||
|
||||
def load_loras(self, specs: list["LoRASpec"]) -> None:
|
||||
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
|
||||
"""Apply LoRAs to the dense Wan2.2-TI2V-5B pipeline.
|
||||
|
||||
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
|
||||
to route the LoRA to the correct expert.
|
||||
|
||||
With ``lazy_load`` the DIT models are ``None`` at this point, so
|
||||
runtime ``switch_lora`` is impossible. Instead we inject
|
||||
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
|
||||
the LoRAs are applied when the models materialise on first inference.
|
||||
|
||||
Without ``lazy_load`` (models already resident) we call
|
||||
``switch_lora`` with explicit high/low keyword args.
|
||||
Dense has a single DIT (no MoE experts), so ``target`` must be
|
||||
``"both"``. GGUF DIT weights don't expose a ``lora_down`` buffer,
|
||||
so ``switch_lora`` would crash — we use the dynamic-apply path that
|
||||
merges LoRAs during GGUF dequant.
|
||||
"""
|
||||
if not specs:
|
||||
return
|
||||
|
||||
# Resolve every path up-front (may trigger HF download).
|
||||
resolved: list[tuple["LoRASpec", str]] = []
|
||||
for spec in specs:
|
||||
if spec.target != "both":
|
||||
raise ValueError(
|
||||
f"Dense 5B Turbo has a single DIT; LoRA target must be "
|
||||
f"'both' (got {spec.target!r})."
|
||||
)
|
||||
local_path = self._resolve_lora_path(spec.path)
|
||||
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
|
||||
spec.name or spec.path, spec.weight, spec.target,
|
||||
local_path)
|
||||
log.info(" LoRA %s → strength=%.2f (%s)",
|
||||
spec.name or spec.path, spec.weight, local_path)
|
||||
resolved.append((spec, local_path))
|
||||
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
# Build the lora_configs list that LightX2V's lazy-load path
|
||||
# reads inside MultiDistillModelStruct.infer().
|
||||
lora_cfgs = []
|
||||
for spec, local_path in resolved:
|
||||
# LightX2V expects name "high_noise_model" / "low_noise_model"
|
||||
cfg_name = {
|
||||
"high_noise": "high_noise_model",
|
||||
"low_noise": "low_noise_model",
|
||||
}.get(spec.target)
|
||||
if cfg_name is None:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
lora_cfgs.append({
|
||||
"name": cfg_name,
|
||||
"path": local_path,
|
||||
"strength": spec.weight,
|
||||
})
|
||||
self._runner.set_config({
|
||||
"lora_configs": lora_cfgs,
|
||||
"lora_dynamic_apply": True,
|
||||
})
|
||||
else:
|
||||
# Models are loaded — use runtime hot-swap.
|
||||
high_path = high_strength = None
|
||||
low_path = low_strength = None
|
||||
for spec, local_path in resolved:
|
||||
if spec.target == "high_noise":
|
||||
high_path, high_strength = local_path, spec.weight
|
||||
elif spec.target == "low_noise":
|
||||
low_path, low_strength = local_path, spec.weight
|
||||
else:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
|
||||
kwargs: dict = {}
|
||||
if high_path is not None:
|
||||
kwargs["high_lora_path"] = high_path
|
||||
kwargs["high_lora_strength"] = high_strength
|
||||
if low_path is not None:
|
||||
kwargs["low_lora_path"] = low_path
|
||||
kwargs["low_lora_strength"] = low_strength
|
||||
ok = self._runner.switch_lora(**kwargs)
|
||||
if not ok:
|
||||
raise RuntimeError(
|
||||
"runner.switch_lora returned False. Check that your "
|
||||
"LightX2V build supports runtime LoRA updates for "
|
||||
f"{self.model_cls}.")
|
||||
|
||||
lora_cfgs = [
|
||||
{"path": local_path, "strength": spec.weight}
|
||||
for spec, local_path in resolved
|
||||
]
|
||||
self._runner.set_config({
|
||||
"lora_configs": lora_cfgs,
|
||||
"lora_dynamic_apply": True,
|
||||
})
|
||||
self._applied_loras = list(specs)
|
||||
|
||||
def unload_loras(self) -> None:
|
||||
"""Remove all currently applied LoRAs."""
|
||||
if not self._applied_loras:
|
||||
return
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
self._runner.set_config({
|
||||
"lora_configs": None,
|
||||
"lora_dynamic_apply": False,
|
||||
})
|
||||
# If models were materialised, drop them so the next inference
|
||||
# recreates them without LoRAs.
|
||||
model_struct = getattr(self._runner, "model", None)
|
||||
if model_struct is not None and hasattr(model_struct, "model"):
|
||||
for i in range(len(model_struct.model)):
|
||||
model_struct.model[i] = None
|
||||
else:
|
||||
self._runner.switch_lora("", 0.0)
|
||||
self._runner.set_config({
|
||||
"lora_configs": None,
|
||||
"lora_dynamic_apply": False,
|
||||
})
|
||||
self._applied_loras = []
|
||||
|
||||
@staticmethod
|
||||
|
||||