working ok
This commit is contained in:
@@ -0,0 +1,57 @@
|
||||
# Agent guide — server/
|
||||
|
||||
The audio pipeline and FastAPI surface. The video stack lives under [video_models/](video_models/) and has its own guide.
|
||||
|
||||
## Module map
|
||||
|
||||
- [main.py](main.py) — FastAPI app, lifespan, `/ws/chat` WebSocket, video/avatar HTTP endpoints. Keep business logic out of here; this is a transport layer.
|
||||
- [models.py](models.py) — `ModelManager.load_all()`. All models are loaded once at startup in a fixed order: VAD → ASR → LLM → TTS → (optional) Video. `video_engine` stays `None` when `config.video.enabled` is false — callers MUST tolerate that.
|
||||
- [config.py](config.py) — thin YAML loader, exposes a single `config` dict. Don't scatter `yaml.safe_load` elsewhere.
|
||||
- [pipeline.py](pipeline.py) — `ConversationSession`, one instance per WebSocket. Owns per-session state (VAD stream, conversation history, KV cache, cancel event). Orchestrates VAD → ASR → LLM → TTS and optionally the video branch.
|
||||
- [vad.py](vad.py) — Silero VAD via ONNX Runtime on CPU. `StreamingVAD.process_chunk(pcm_16k) → utterance | None`. Returns a full utterance only on speech→silence transition.
|
||||
- [asr.py](asr.py) — Qwen3-ASR wrapper. Sync, called under `asyncio.to_thread`.
|
||||
- [llm.py](llm.py) — two backends behind a common `generate(history, max_new_tokens, kv_cache_state) → (text, KVCacheState)` signature: `LLMEngine` (local transformers) and `LMStudioEngine` (HTTP, no KV cache).
|
||||
- [tts.py](tts.py) — Kokoro wrapper. The per-segment generator yields `(graphemes, _ps, audio_f32)` tuples at 24 kHz.
|
||||
- [audio_utils.py](audio_utils.py) — `pcm_bytes_to_float32` / `float32_to_pcm_bytes`. The WebSocket protocol is 16 kHz int16 PCM in, 24 kHz int16 PCM out.
|
||||
- [video.py](video.py) — `VideoConfig`, `LoRASpec`, `VideoEngine`. Orchestrates Wan2.2 + MuseTalk. Gated by `config.video.enabled`.
|
||||
- [video_models/](video_models/) — see [video_models/AGENT.md](video_models/AGENT.md) for Blackwell/GGUF gotchas.
|
||||
|
||||
## Session lifecycle (pipeline.py)
|
||||
|
||||
1. Client connects → `ConversationSession(models, send_json, send_bytes)` → `start()` emits `{type: "status", state: "listening"}` and a `{type: "video_mode", ...}` hint.
|
||||
2. Inbound binary frames are 16 kHz int16 PCM → `handle_audio_chunk` → VAD. On speech→silence the session kicks off `_process_utterance` as an `asyncio.Task` so it never blocks the WebSocket receive loop.
|
||||
3. `_process_utterance` flow: ASR → LLM → TTS stream. Each blocking call wraps in `asyncio.to_thread`.
|
||||
4. TTS output is split into short segments via `_split_into_segments` before synthesis so streaming chunks stay small.
|
||||
5. TTS runs on a background `threading.Thread`, feeding a `queue.Queue` that the async loop drains.
|
||||
|
||||
## Barge-in
|
||||
|
||||
- `cancel_event: threading.Event` is the single stop signal. Checked between every stage and inside the TTS queue drain loop.
|
||||
- Two ways to trigger: new VAD utterance while `is_responding` (`handle_audio_chunk`), or an explicit `{type: "interrupt", last_chunk_id}` text message (`interrupt`).
|
||||
- On cancel: set the event, send `{type: "interrupt"}` to the client so it flushes its audio buffer, and discard any pending video clip.
|
||||
- Don't add work after `cancel_event.is_set()` without checking — that's how zombie audio/video reaches the client after a barge-in.
|
||||
|
||||
## Video branch
|
||||
|
||||
The audio pipeline changes shape when `video_engine.is_ready()`:
|
||||
|
||||
- PCM chunks and `response_text` are **not** streamed during the turn — they're buffered.
|
||||
- TTS audio is concatenated into one float32 array.
|
||||
- After TTS completes, `video_engine.generate_speaking_clip(audio, sr, reply_text)` renders the MP4 (blocking, wrapped in `to_thread`).
|
||||
- The full clip + final text is sent as a single `speaking_clip` message.
|
||||
|
||||
If you extend the audio pipeline, preserve this dual-mode behaviour: the client's UX is very different between the two paths, and mixing them (e.g. sending PCM *and* a clip) will double-play the audio.
|
||||
|
||||
## Conventions
|
||||
|
||||
- Dataclasses for structured config (see `VideoConfig`, `LoRASpec`). Parse once in `*.from_dict`; don't re-read `config.yml` mid-session.
|
||||
- `asyncio.to_thread` for any sync model call from an async context. Never call `.generate()` / `.transcribe()` / `.pipeline()` directly on the event loop.
|
||||
- Locks: `VideoEngine._lock` serialises model state mutations; `ConversationSession` is not locked because each WebSocket gets its own instance.
|
||||
- Logging: one `log = logging.getLogger(__name__)` per module. INFO for lifecycle and per-turn milestones; avoid DEBUG spam in hot loops.
|
||||
- Keep `main.py` thin. New endpoints should delegate to a method on `ModelManager` or an engine class.
|
||||
|
||||
## Testing
|
||||
|
||||
- `tests/unit/test_pipeline_video_branch.py` — the video vs. audio path selection. Update it if you change the `use_video` condition.
|
||||
- `tests/unit/test_video_config.py` / `test_video_engine_logic.py` — config parsing and the pure logic in `video.py`.
|
||||
- Component tests live in `tests/component/` and require the Docker GPU environment. See [tests/README.md](../tests/README.md).
|
||||
+1
-1
@@ -133,7 +133,7 @@ async def reload_loras(body: dict):
|
||||
if not entry or "path" not in entry:
|
||||
continue
|
||||
target = str(entry.get("target", "both")).lower()
|
||||
if target not in ("high_noise", "low_noise", "both"):
|
||||
if target != "both":
|
||||
target = "both"
|
||||
specs.append(
|
||||
LoRASpec(
|
||||
|
||||
+1
-2
@@ -121,8 +121,7 @@ class ModelManager:
|
||||
log.info("Loading avatar video engine...")
|
||||
cfg = VideoConfig.from_dict(video_cfg_raw)
|
||||
self.video_engine = VideoEngine(cfg)
|
||||
if cfg.loras:
|
||||
self.video_engine.load_loras(cfg.loras)
|
||||
self.video_engine.load_models()
|
||||
log.info("Avatar video engine loaded (mode=%s).", cfg.mode)
|
||||
|
||||
def create_vad(self) -> StreamingVAD:
|
||||
|
||||
+49
-37
@@ -18,18 +18,18 @@ import numpy as np
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LoRATarget = Literal["high_noise", "low_noise", "both"]
|
||||
LoRATarget = Literal["both"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRASpec:
|
||||
"""One LoRA adapter entry from ``config.video.loras``.
|
||||
|
||||
Wan2.2 I2V is a Mixture-of-Experts model with separate high-noise and
|
||||
low-noise sub-models. Most LightX2V distill LoRAs come paired (one per
|
||||
sub-model) and must be applied to the correct target. Allow
|
||||
``target="both"`` for LoRAs that should be applied to both sub-models
|
||||
(e.g. style LoRAs).
|
||||
The dense Wan2.2-TI2V-5B DIT has a single set of weights (no MoE
|
||||
experts), so ``target`` is always ``"both"``. The field is kept for
|
||||
forward compatibility and config-file compatibility with older MoE
|
||||
configs — legacy ``"high_noise"`` / ``"low_noise"`` values are coerced
|
||||
to ``"both"`` in ``VideoConfig.from_dict``.
|
||||
"""
|
||||
|
||||
path: str
|
||||
@@ -60,18 +60,20 @@ class VideoConfig:
|
||||
# Model paths — can be overridden via config.yml.video.models.
|
||||
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
|
||||
# The bf16 DIT shards in this repo are skipped — we
|
||||
# replace them with quantised files from wan22_dit_repo.
|
||||
# wan22_dit_repo : HF repo id (or local dir) providing the quantised
|
||||
# DIT checkpoints (fp8 or GGUF).
|
||||
# wan22_dit_quant_scheme : quantisation format, e.g. "fp8-sgl" or "gguf-Q4_K_M".
|
||||
# replace them with a quantised GGUF from wan22_dit_repo.
|
||||
# wan22_dit_repo : HF repo id (or local dir) providing the single
|
||||
# dense GGUF DIT checkpoint (5B Turbo).
|
||||
# wan22_dit_quant_scheme : GGUF quant level, e.g. "gguf-Q8_0" (default)
|
||||
# or "gguf-Q4_K_M" for lower VRAM.
|
||||
# wan22_config_json : path to the LightX2V inference config template the
|
||||
# Wan22Pipeline will fill in with absolute ckpt paths.
|
||||
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
|
||||
wan22_dit_repo: str = "lightx2v/Wan2.2-Distill-Models"
|
||||
wan22_dit_quant_scheme: str = "fp8-sgl"
|
||||
wan22_t5_quantized: bool = False
|
||||
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
||||
wan22_model_cls: str = "wan2.2_moe_distill"
|
||||
wan22_base_repo: str = "Wan-AI/Wan2.2-TI2V-5B"
|
||||
wan22_dit_repo: str = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
|
||||
wan22_dit_quant_scheme: str = "gguf-Q8_0"
|
||||
wan22_t5_quantized: bool = True
|
||||
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
|
||||
wan22_model_cls: str = "wan2.2"
|
||||
musetalk_enabled: bool = True
|
||||
musetalk_model_path: str = "TMElyralab/MuseTalk"
|
||||
|
||||
@classmethod
|
||||
@@ -92,9 +94,10 @@ class VideoConfig:
|
||||
if not entry or "path" not in entry:
|
||||
continue
|
||||
target = str(entry.get("target", "both")).lower()
|
||||
if target not in ("high_noise", "low_noise", "both"):
|
||||
if target != "both":
|
||||
log.warning(
|
||||
"LoRA %s: invalid target %r, defaulting to 'both'",
|
||||
"LoRA %s: target %r is MoE-era; coercing to 'both' "
|
||||
"(dense 5B has a single DIT).",
|
||||
entry.get("path"), target,
|
||||
)
|
||||
target = "both"
|
||||
@@ -122,30 +125,32 @@ class VideoConfig:
|
||||
reflective_prompt_reply_words=int(reflective.get("prompt_reply_words", 18)),
|
||||
loras=loras,
|
||||
wan22_base_repo=str(
|
||||
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
|
||||
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-TI2V-5B")
|
||||
),
|
||||
wan22_dit_repo=str(
|
||||
models_raw.get(
|
||||
"wan22_dit_repo",
|
||||
# Backwards compat: fall back to old key name.
|
||||
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models"),
|
||||
"hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF",
|
||||
)
|
||||
),
|
||||
wan22_dit_quant_scheme=str(
|
||||
models_raw.get("wan22_dit_quant_scheme", "fp8-sgl")
|
||||
models_raw.get("wan22_dit_quant_scheme", "gguf-Q8_0")
|
||||
),
|
||||
wan22_t5_quantized=bool(
|
||||
models_raw.get("wan22_t5_quantized", False)
|
||||
models_raw.get("wan22_t5_quantized", True)
|
||||
),
|
||||
wan22_config_json=str(
|
||||
models_raw.get(
|
||||
"wan22_config_json",
|
||||
"/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
|
||||
"/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json",
|
||||
)
|
||||
),
|
||||
wan22_model_cls=str(
|
||||
models_raw.get("wan22_model_cls", "wan2.2_moe_distill")
|
||||
models_raw.get("wan22_model_cls", "wan2.2")
|
||||
),
|
||||
musetalk_enabled=bool(raw.get("musetalk", {}).get("enabled", True))
|
||||
if isinstance(raw.get("musetalk"), dict)
|
||||
else bool(raw.get("musetalk_enabled", True)),
|
||||
musetalk_model_path=str(
|
||||
models_raw.get("musetalk_path", "TMElyralab/MuseTalk")
|
||||
),
|
||||
@@ -235,17 +240,22 @@ class VideoEngine:
|
||||
self._wan22.load_loras(self.cfg.loras)
|
||||
log.info("Wan2.2 pipeline ready.")
|
||||
|
||||
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
|
||||
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
|
||||
log.info("MuseTalk engine ready.")
|
||||
if self.cfg.musetalk_enabled:
|
||||
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
|
||||
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
|
||||
log.info("MuseTalk engine ready.")
|
||||
else:
|
||||
log.info("MuseTalk disabled via config — skipping lip-sync pass.")
|
||||
self._musetalk = None
|
||||
|
||||
# --- Readiness ------------------------------------------------------
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""True when an avatar is set and a speaking clip can be produced."""
|
||||
musetalk_ok = (not self.cfg.musetalk_enabled) or self._musetalk is not None
|
||||
return (
|
||||
self._wan22 is not None
|
||||
and self._musetalk is not None
|
||||
and musetalk_ok
|
||||
and self.avatar_path is not None
|
||||
and self.idle_clip_mp4 is not None
|
||||
)
|
||||
@@ -336,7 +346,6 @@ class VideoEngine:
|
||||
"(avatar set? models loaded?)"
|
||||
)
|
||||
assert self._wan22 is not None
|
||||
assert self._musetalk is not None
|
||||
|
||||
# 1. Source base frames.
|
||||
if self.cfg.mode == "library":
|
||||
@@ -351,13 +360,16 @@ class VideoEngine:
|
||||
seed=None, # random each turn
|
||||
)
|
||||
|
||||
# 2. Lip-sync the base frames to the given audio.
|
||||
synced_frames = self._musetalk.lip_sync(
|
||||
frames=base_frames,
|
||||
audio=audio_f32,
|
||||
sample_rate=sample_rate,
|
||||
fps=self.cfg.fps,
|
||||
)
|
||||
# 2. Lip-sync the base frames to the given audio (if enabled).
|
||||
if self._musetalk is not None:
|
||||
synced_frames = self._musetalk.lip_sync(
|
||||
frames=base_frames,
|
||||
audio=audio_f32,
|
||||
sample_rate=sample_rate,
|
||||
fps=self.cfg.fps,
|
||||
)
|
||||
else:
|
||||
synced_frames = base_frames
|
||||
|
||||
# 3. Mux frames + audio into an MP4.
|
||||
from server.video_models.muxer import frames_and_audio_to_mp4
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# Agent guide — server/video_models/
|
||||
|
||||
Wrappers around 3rd-party video models. These are the trickiest files in the repo: LightX2V's internals move quickly upstream, and the Blackwell (RTX 5090 / SM120) GPU path requires several non-obvious patches layered on top. Read this before editing.
|
||||
|
||||
## Scope
|
||||
|
||||
- [wan22.py](wan22.py) — LightX2V Wan2.2-I2V A14B MoE pipeline. Supports fp8 safetensors and GGUF DIT checkpoints. Loaded once at startup, held resident; per-turn calls go through `generate_i2v` and `switch_lora`.
|
||||
- [musetalk.py](musetalk.py) — MuseTalk lip-sync over base frames + TTS audio.
|
||||
- [muxer.py](muxer.py) — thin ffmpeg wrappers: frames → MP4 loop, frames + audio → MP4.
|
||||
|
||||
Nothing here is imported unless `config.video.enabled` is true.
|
||||
|
||||
## LightX2V entry points (upstream API)
|
||||
|
||||
Use these symbols, not internal/private ones:
|
||||
|
||||
```python
|
||||
from lightx2v.utils.set_config import set_config
|
||||
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
|
||||
from lightx2v.infer import init_runner
|
||||
|
||||
config = set_config(args) # args is an argparse.Namespace
|
||||
input_info = init_empty_input_info(args.task, args.support_tasks)
|
||||
runner = init_runner(config) # loads all weights — ONCE
|
||||
|
||||
update_input_info_from_dict(input_info, {...}) # per-turn inputs
|
||||
runner.run_pipeline(input_info) # MP4 written to save_result_path
|
||||
runner.switch_lora(lora_path, strength) # hot-swap
|
||||
```
|
||||
|
||||
Keep model load out of the per-turn path — `init_runner` is expensive.
|
||||
|
||||
## Blackwell (SM120) patches — do not remove without testing
|
||||
|
||||
The GGUF pipeline works on a 5090 only because of layered patches in `wan22.py` and tuning in the LightX2V JSON configs under [configs/lightx2v/](../../configs/lightx2v/). Each patch exists because a stock upstream path segfaults or silently miscomputes on SM120.
|
||||
|
||||
**Dtype plumbing (GGUF path):**
|
||||
|
||||
- Default `DTYPE` must be `BF16` at `init_runner()` time — T5 offload buffers break if FP16 at init.
|
||||
- Flip `BF16 → FP16` *after* `init_runner()`.
|
||||
- Wrap T5 encoder so it runs under BF16 internally, then cast outputs `bf16 → fp16` before handing to the DIT. See `_patch_t5_dtype_for_gguf`.
|
||||
- Cast VAE **both** layers: the inner `.model` via `.to(fp16)` **and** the outer `WanVAE` wrapper's `mean` / `inv_std` / `scale` tensors. Missing the wrapper tensors upcasts the latent during decode's `z/inv_std + mean`.
|
||||
- DIT `pre_weight.patch_embedding.pin_weight` loads as fp32 (only `pin_bias` is fp16). Cast **and** re-pin via `.pin_memory()` — skipping re-pin segfaults during `to_cuda` H2D copy.
|
||||
- `sgl_kernel`'s fp8 scaled matmul is patched to `torch._scaled_mm` in `_patch_fp8_scaled_mm_for_blackwell`.
|
||||
|
||||
**LightX2V JSON config (see `wan22_i2v_gguf_distill.json`):**
|
||||
|
||||
- `modulate_type: "torch"` — Triton `fuse_scale_shift_kernel` segfaults in `ast_to_ttir` on Triton 3.4 + SM120.
|
||||
- `rope_type: "torch"` — flashinfer isn't installed.
|
||||
- `self_attn_1_type` / `cross_attn_*_type`: `"torch_sdpa"` — flash_attn3 unavailable; `sageattention==1.0.6` from PyPI segfaults on Blackwell (newer requires source build).
|
||||
|
||||
If you add a new quant scheme or a new model_cls, create its own JSON under `configs/lightx2v/` mirroring these choices, and exercise it end-to-end via a new `tests/component/test_NN_*.py` before wiring it into the default config.
|
||||
|
||||
## HF download layout
|
||||
|
||||
`Wan-AI/Wan2.2-I2V-A14B` ships ~28 GB of bf16 DIT shards we replace with the quantised `dit_repo`. `BASE_REPO_IGNORE_PATTERNS` in [wan22.py](wan22.py) excludes them but **keeps** `high_noise_model/*.json` and `low_noise_model/*.json` — `set_config` parses architecture params (`dim`, etc.) from those. Don't broaden the ignore pattern without checking.
|
||||
|
||||
Supported quant schemes live in `wan22_dit_quant_scheme`:
|
||||
|
||||
- `fp8-sgl` — `lightx2v/Wan2.2-Distill-Models`, two `.safetensors` files
|
||||
- `gguf-Q4_K_M`, `gguf-Q8_0`, … — `QuantStack/Wan2.2-I2V-A14B-GGUF`, layout `HighNoise/…` and `LowNoise/…`
|
||||
|
||||
Filenames are templated at the top of `wan22.py`; update those if the upstream repos rename files.
|
||||
|
||||
## Testing
|
||||
|
||||
- `tests/component/test_02_wan22_loras.py` — full pipeline load + LoRA apply
|
||||
- `tests/component/test_09_gguf_generate.py` — GGUF end-to-end I2V
|
||||
- `tests/component/test_10_t5_encode.py` — T5 encoder dtype path
|
||||
- `tests/component/test_11_image_encode.py` — image → VAE latent
|
||||
- `tests/component/test_12_dit_single_step.py` — one DIT step per expert
|
||||
- `tests/component/test_13_vae_decode.py` — VAE decode → RGB
|
||||
|
||||
When diagnosing a Blackwell regression, run 10 → 11 → 12 → 13 in that order; the failure localises to the first failing stage.
|
||||
|
||||
## LoRAs
|
||||
|
||||
`switch_lora(path, strength)` applies; `switch_lora("", 0.0)` removes. `load_loras`/`unload_loras` in this wrapper iterate over `LoRASpec`s from config and call `switch_lora` per `target` sub-model (`high_noise`, `low_noise`, or `both`). Wrong target = silently wrong output.
|
||||
@@ -32,8 +32,13 @@ class MuseTalkEngine:
|
||||
def _load_impl(model_path: str):
|
||||
"""Load the MuseTalk inference implementation.
|
||||
|
||||
If none of the known entry points work the error message points at
|
||||
this file so you know where to fix it.
|
||||
Upstream MuseTalk has no library-style entry point — it's a bundle
|
||||
of training/inference CLI scripts. The bhetherman/MuseTalk fork at
|
||||
``third_party/MuseTalk`` adds package metadata but the low-level
|
||||
API is still the raw ``musetalk.utils.*`` and ``musetalk.models.*``
|
||||
modules. We import them here to verify the install succeeded; the
|
||||
actual pipeline (VAE, UNet, Whisper, face detection, blending)
|
||||
is wired up inside ``MuseTalkEngine.lip_sync``.
|
||||
"""
|
||||
resolved = model_path
|
||||
if not os.path.isdir(model_path) and "/" in model_path:
|
||||
@@ -43,28 +48,19 @@ class MuseTalkEngine:
|
||||
except Exception as e: # pragma: no cover
|
||||
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
|
||||
|
||||
# Try upstream MuseTalk repo layout.
|
||||
try:
|
||||
from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found]
|
||||
return MuseTalkInference(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found]
|
||||
return MuseTalkInfer(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from musetalk import Inference # type: ignore[import-not-found]
|
||||
return Inference(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
from musetalk.utils.utils import load_all_model # type: ignore[import-not-found] # noqa: F401
|
||||
from musetalk.utils.audio_processor import AudioProcessor # type: ignore[import-not-found] # noqa: F401
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MuseTalk Python package is not importable. "
|
||||
"Check that third_party/MuseTalk was installed via "
|
||||
"`pip install /opt/MuseTalk` in the Dockerfile."
|
||||
) from e
|
||||
|
||||
raise RuntimeError(
|
||||
"MuseTalk is installed but no known Python entry point was found. "
|
||||
"Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
|
||||
"to match the installed MuseTalk version."
|
||||
)
|
||||
# Return the resolved weight path; lip_sync loads models lazily on
|
||||
# first call so import-time failures don't block voice-only startup.
|
||||
return {"model_path": resolved, "loaded": False}
|
||||
|
||||
# --- Inference ---------------------------------------------------------
|
||||
|
||||
@@ -98,31 +94,22 @@ class MuseTalkEngine:
|
||||
if target_t > 0 and len(frames) != target_t:
|
||||
frames = _fit_frames_to_length(frames, target_t)
|
||||
|
||||
# The real MuseTalk call signature varies. Most common is a method
|
||||
# like ``run(frames, audio, sr, fps)`` or ``infer(...)``.
|
||||
for method_name in ("run", "infer", "lip_sync", "__call__"):
|
||||
method = getattr(self._infer, method_name, None)
|
||||
if method is None:
|
||||
continue
|
||||
try:
|
||||
result = method(
|
||||
frames=frames,
|
||||
audio=audio,
|
||||
sample_rate=sample_rate,
|
||||
fps=fps,
|
||||
)
|
||||
return _ensure_uint8_rgb(result)
|
||||
except TypeError:
|
||||
# Try positional
|
||||
try:
|
||||
result = method(frames, audio, sample_rate, fps)
|
||||
return _ensure_uint8_rgb(result)
|
||||
except TypeError:
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
"MuseTalk wrapper could not find a working inference method. "
|
||||
"Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
|
||||
# MuseTalk's real inference path (see third_party/MuseTalk/scripts/
|
||||
# realtime_inference.py::Avatar.inference) needs:
|
||||
# - mmpose + mmcv + mmengine (dwpose keypoint detection)
|
||||
# - face_alignment (bbox)
|
||||
# - MuseTalk UNet + VAE weights (TMElyralab/MuseTalk HF repo)
|
||||
# - Whisper encoder (openai/whisper-tiny)
|
||||
# - face_parsing weights
|
||||
# Plus its preprocessing module has import-time side effects that
|
||||
# load dwpose weights from a CWD-relative path. Turn the full
|
||||
# pipeline on by extending this method once those deps are
|
||||
# installed and weights are resolved — until then, callers should
|
||||
# keep ``config.video.musetalk.enabled: false`` and VideoEngine
|
||||
# will skip the lip-sync pass.
|
||||
raise NotImplementedError(
|
||||
"MuseTalk lip-sync pipeline is not wired up yet. "
|
||||
"Set config.video.musetalk.enabled=false to bypass."
|
||||
)
|
||||
|
||||
|
||||
|
||||
+145
-187
@@ -1,4 +1,4 @@
|
||||
"""Wan2.2-Lightning image-to-video wrapper via LightX2V.
|
||||
"""Wan2.2-TI2V-5B-Turbo (dense) image-to-video wrapper via LightX2V.
|
||||
|
||||
This wrapper targets LightX2V's actual Python entry points (verified against
|
||||
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
|
||||
@@ -22,15 +22,12 @@ Model weights are loaded once at construction and held resident across turns
|
||||
so reflective mode doesn't re-pay the load cost each reply.
|
||||
|
||||
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
|
||||
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards under high_noise_model/
|
||||
and low_noise_model/ are SKIPPED via
|
||||
ignore_patterns — we replace them with
|
||||
quantised checkpoints from dit_repo.
|
||||
- dit_repo (configurable) — quantised DIT checkpoints. Supported
|
||||
formats:
|
||||
* fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
|
||||
* GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)
|
||||
- Wan-AI/Wan2.2-TI2V-5B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards are SKIPPED via
|
||||
ignore_patterns — replaced by the GGUF
|
||||
checkpoint from dit_repo.
|
||||
- dit_repo (configurable) — single dense GGUF DIT checkpoint, e.g.
|
||||
hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -49,27 +46,20 @@ if TYPE_CHECKING:
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# --- fp8 distill filenames --------------------------------------------------
|
||||
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
|
||||
# --- GGUF filenames (QuantStack layout: HighNoise/<name>.gguf) ---------------
|
||||
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
|
||||
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"
|
||||
# --- GGUF filename for the dense 5B Turbo repo ------------------------------
|
||||
# hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF ships flat: Wan2_2-TI2V-5B-Turbo-{quant}.gguf
|
||||
GGUF_TURBO_5B_FILE = "Wan2_2-TI2V-5B-Turbo-{quant}.gguf"
|
||||
|
||||
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
|
||||
T5_FP8_REPO = "lightx2v/Encoders"
|
||||
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
|
||||
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
|
||||
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the
|
||||
# quantised files from dit_repo replace the DIT weights entirely. We must
|
||||
# keep the config.json / index.json metadata under high_noise_model/ and
|
||||
# low_noise_model/ (LightX2V's set_config reads architecture params like
|
||||
# ``dim`` from them) and the tokenizer files under google/.
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards alongside the T5/VAE/
|
||||
# tokenizer support files. We only need the latter — the GGUF from dit_repo
|
||||
# replaces the DIT weights entirely. Keep config.json / tokenizer files.
|
||||
BASE_REPO_IGNORE_PATTERNS = [
|
||||
"high_noise_model/*.safetensors",
|
||||
"low_noise_model/*.safetensors",
|
||||
"*.pt",
|
||||
"diffusion_pytorch_model*.safetensors",
|
||||
"assets/*",
|
||||
"examples/*",
|
||||
"nohup.out",
|
||||
@@ -77,6 +67,40 @@ BASE_REPO_IGNORE_PATTERNS = [
|
||||
]
|
||||
|
||||
|
||||
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
|
||||
"""Recursively find fp32 tensors reachable from ``obj`` and cast to fp16.
|
||||
|
||||
The dense ``wan2.2`` DIT isn't a standard ``nn.Module`` — some fp32
|
||||
tensors (conv3d bias etc.) live outside ``pre_weight``/``post_weight``
|
||||
and are missed by the structured sweep. This generic traversal catches
|
||||
them. Bounded depth + visited-set to avoid cycles.
|
||||
"""
|
||||
import torch
|
||||
if visited is None:
|
||||
visited = set()
|
||||
obj_id = id(obj)
|
||||
if obj_id in visited or depth > 6:
|
||||
return 0
|
||||
visited.add(obj_id)
|
||||
n = 0
|
||||
for attr_name in dir(obj):
|
||||
if attr_name.startswith("__"):
|
||||
continue
|
||||
try:
|
||||
val = getattr(obj, attr_name)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
|
||||
try:
|
||||
setattr(obj, attr_name, val.to(torch.float16))
|
||||
n += 1
|
||||
except Exception:
|
||||
pass
|
||||
elif hasattr(val, "__dict__") and not callable(val):
|
||||
n += _cast_all_fp32_tensors(val, visited, depth + 1)
|
||||
return n
|
||||
|
||||
|
||||
def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
|
||||
|
||||
@@ -132,13 +156,11 @@ def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||
|
||||
|
||||
class Wan22Pipeline:
|
||||
"""Wrapper around LightX2V's Wan2.2 MoE distill runner.
|
||||
"""Wrapper around LightX2V's dense Wan2.2-TI2V-5B-Turbo runner.
|
||||
|
||||
Supports two DIT quantisation formats:
|
||||
* **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
|
||||
``lightx2v/Wan2.2-Distill-Models``)
|
||||
* **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
|
||||
from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)
|
||||
The 5B Turbo repo ships a single dense DIT checkpoint (not MoE) as GGUF.
|
||||
``dit_quant_scheme`` must be a GGUF variant (``gguf-Q8_0`` default,
|
||||
``gguf-Q4_K_M`` for lower VRAM); no fp8 5B Turbo weights exist.
|
||||
|
||||
Constructor downloads (if needed) both HF repos, writes a runtime JSON
|
||||
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
|
||||
@@ -150,11 +172,11 @@ class Wan22Pipeline:
|
||||
base_repo: str,
|
||||
dit_repo: str,
|
||||
config_json: str,
|
||||
model_cls: str = "wan2.2_moe_distill",
|
||||
model_cls: str = "wan2.2",
|
||||
resolution: int = 480,
|
||||
fps: int = 16,
|
||||
dit_quant_scheme: str = "fp8-sgl",
|
||||
t5_quantized: bool = False,
|
||||
dit_quant_scheme: str = "gguf-Q8_0",
|
||||
t5_quantized: bool = True,
|
||||
):
|
||||
self.base_repo = base_repo
|
||||
self.dit_repo = dit_repo
|
||||
@@ -167,10 +189,15 @@ class Wan22Pipeline:
|
||||
self._applied_loras: list[LoRASpec] = []
|
||||
|
||||
self._is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||
if not self._is_gguf:
|
||||
raise ValueError(
|
||||
f"dit_quant_scheme must be a GGUF variant for dense 5B Turbo "
|
||||
f"(got {dit_quant_scheme!r}); no fp8 5B Turbo weights exist."
|
||||
)
|
||||
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpt.
|
||||
self._model_root = self._ensure_base_repo(base_repo)
|
||||
self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
|
||||
self._dit_ckpt = self._ensure_dit_checkpoint(
|
||||
dit_repo, dit_quant_scheme,
|
||||
)
|
||||
self._t5_fp8_ckpt = (
|
||||
@@ -306,52 +333,53 @@ class Wan22Pipeline:
|
||||
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
|
||||
|
||||
def _patch_dit_fp32_weights_for_gguf(self) -> None:
|
||||
"""Cast leftover fp32 DIT pre/post weights to fp16.
|
||||
"""Cast leftover fp32 DIT weights to fp16 (dense model).
|
||||
|
||||
GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
|
||||
of non-quantised weights (notably ``patch_embedding.pin_weight``) end
|
||||
up loaded as fp32. That breaks the first conv in the DIT forward pass
|
||||
(fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
|
||||
runs uniformly in fp16.
|
||||
GGUF dequantises the transformer blocks to fp16, but a handful of
|
||||
non-quantised weights (notably ``patch_embedding.pin_weight``) end up
|
||||
loaded as fp32. That breaks the first conv in the DIT forward pass
|
||||
(fp16 input vs fp32 weight). Dense ``wan2.2`` exposes the model
|
||||
directly at ``runner.model`` (no MoE wrapper). After the structured
|
||||
pre/post weight sweep, we also run a recursive traversal to catch
|
||||
fp32 conv3d biases etc. that live outside pre/post_weight.
|
||||
"""
|
||||
import torch
|
||||
|
||||
runner = self._runner
|
||||
models = getattr(runner.model, "model", None)
|
||||
if models is None:
|
||||
return
|
||||
if not isinstance(models, (list, tuple)):
|
||||
models = [models]
|
||||
n_struct = self._cast_fp32_dit_weights_in_model(runner.model)
|
||||
n_extra = _cast_all_fp32_tensors(runner.model)
|
||||
log.info(
|
||||
"Cast %d (structured) + %d (recursive) fp32 DIT tensors to fp16 for GGUF pipeline.",
|
||||
n_struct, n_extra,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _cast_fp32_dit_weights_in_model(m) -> int:
|
||||
import torch
|
||||
n_cast = 0
|
||||
for m in models:
|
||||
for weights_attr in ("pre_weight", "post_weight"):
|
||||
w = getattr(m, weights_attr, None)
|
||||
if w is None:
|
||||
for weights_attr in ("pre_weight", "post_weight"):
|
||||
w = getattr(m, weights_attr, None)
|
||||
if w is None:
|
||||
continue
|
||||
for sub_name in dir(w):
|
||||
if sub_name.startswith("_"):
|
||||
continue
|
||||
for sub_name in dir(w):
|
||||
if sub_name.startswith("_"):
|
||||
continue
|
||||
try:
|
||||
sub = getattr(w, sub_name)
|
||||
except Exception:
|
||||
continue
|
||||
if sub is None:
|
||||
continue
|
||||
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
|
||||
t = getattr(sub, t_name, None)
|
||||
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
|
||||
casted = t.to(torch.float16)
|
||||
# Preserve pinned-memory status on pin_* tensors so
|
||||
# move_attr_to_cuda's non-blocking H2D copy is safe.
|
||||
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
|
||||
try:
|
||||
casted = casted.pin_memory()
|
||||
except RuntimeError:
|
||||
pass
|
||||
setattr(sub, t_name, casted)
|
||||
n_cast += 1
|
||||
log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
|
||||
try:
|
||||
sub = getattr(w, sub_name)
|
||||
except Exception:
|
||||
continue
|
||||
if sub is None:
|
||||
continue
|
||||
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
|
||||
t = getattr(sub, t_name, None)
|
||||
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
|
||||
casted = t.to(torch.float16)
|
||||
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
|
||||
try:
|
||||
casted = casted.pin_memory()
|
||||
except RuntimeError:
|
||||
pass
|
||||
setattr(sub, t_name, casted)
|
||||
n_cast += 1
|
||||
return n_cast
|
||||
|
||||
# --- Weight provisioning -------------------------------------------------
|
||||
|
||||
@@ -374,46 +402,34 @@ class Wan22Pipeline:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_dit_checkpoints(
|
||||
def _ensure_dit_checkpoint(
|
||||
dit_repo: str,
|
||||
dit_quant_scheme: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Return (high_noise_path, low_noise_path) for the DIT pair.
|
||||
|
||||
Supports both fp8 safetensors and GGUF formats.
|
||||
"""
|
||||
) -> str:
|
||||
"""Return the local path to the single dense GGUF DIT checkpoint."""
|
||||
if not dit_repo:
|
||||
raise ValueError("dit_repo must be a HF repo id or local directory.")
|
||||
if not dit_quant_scheme.startswith("gguf-"):
|
||||
raise ValueError(
|
||||
f"Only GGUF quant schemes are supported for dense 5B Turbo "
|
||||
f"(got {dit_quant_scheme!r})."
|
||||
)
|
||||
|
||||
is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||
quant = dit_quant_scheme.replace("gguf-", "")
|
||||
filename = GGUF_TURBO_5B_FILE.format(quant=quant)
|
||||
|
||||
if is_gguf:
|
||||
# Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M"
|
||||
quant = dit_quant_scheme.replace("gguf-", "")
|
||||
high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
|
||||
low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
|
||||
else:
|
||||
high_file = FP8_HIGH_NOISE_FILE
|
||||
low_file = FP8_LOW_NOISE_FILE
|
||||
|
||||
# Local directory?
|
||||
if os.path.isdir(dit_repo):
|
||||
high = os.path.join(dit_repo, high_file)
|
||||
low = os.path.join(dit_repo, low_file)
|
||||
if not (os.path.isfile(high) and os.path.isfile(low)):
|
||||
path = os.path.join(dit_repo, filename)
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(
|
||||
f"DIT checkpoints not found in {dit_repo}: expected "
|
||||
f"{high_file} and {low_file}"
|
||||
f"DIT checkpoint not found in {dit_repo}: expected {filename}"
|
||||
)
|
||||
return high, low
|
||||
return path
|
||||
|
||||
# HuggingFace download.
|
||||
from huggingface_hub import hf_hub_download
|
||||
log.info("Downloading %s DIT checkpoints from %s ...",
|
||||
log.info("Downloading %s DIT checkpoint from %s ...",
|
||||
dit_quant_scheme, dit_repo)
|
||||
high = hf_hub_download(repo_id=dit_repo, filename=high_file)
|
||||
low = hf_hub_download(repo_id=dit_repo, filename=low_file)
|
||||
return high, low
|
||||
return hf_hub_download(repo_id=dit_repo, filename=filename)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_t5_fp8() -> str:
|
||||
@@ -431,8 +447,7 @@ class Wan22Pipeline:
|
||||
cfg = json.load(f)
|
||||
# Drop editorial comments before passing to LightX2V.
|
||||
cfg.pop("_comment", None)
|
||||
cfg["high_noise_quantized_ckpt"] = self._dit_high
|
||||
cfg["low_noise_quantized_ckpt"] = self._dit_low
|
||||
cfg["dit_quantized_ckpt"] = self._dit_ckpt
|
||||
cfg.setdefault("fps", self.fps)
|
||||
|
||||
# T5 fp8 quantization.
|
||||
@@ -496,103 +511,46 @@ class Wan22Pipeline:
|
||||
# --- LoRA --------------------------------------------------------------
|
||||
|
||||
def load_loras(self, specs: list["LoRASpec"]) -> None:
|
||||
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
|
||||
"""Apply LoRAs to the dense Wan2.2-TI2V-5B pipeline.
|
||||
|
||||
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
|
||||
to route the LoRA to the correct expert.
|
||||
|
||||
With ``lazy_load`` the DIT models are ``None`` at this point, so
|
||||
runtime ``switch_lora`` is impossible. Instead we inject
|
||||
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
|
||||
the LoRAs are applied when the models materialise on first inference.
|
||||
|
||||
Without ``lazy_load`` (models already resident) we call
|
||||
``switch_lora`` with explicit high/low keyword args.
|
||||
Dense has a single DIT (no MoE experts), so ``target`` must be
|
||||
``"both"``. GGUF DIT weights don't expose a ``lora_down`` buffer,
|
||||
so ``switch_lora`` would crash — we use the dynamic-apply path that
|
||||
merges LoRAs during GGUF dequant.
|
||||
"""
|
||||
if not specs:
|
||||
return
|
||||
|
||||
# Resolve every path up-front (may trigger HF download).
|
||||
resolved: list[tuple["LoRASpec", str]] = []
|
||||
for spec in specs:
|
||||
if spec.target != "both":
|
||||
raise ValueError(
|
||||
f"Dense 5B Turbo has a single DIT; LoRA target must be "
|
||||
f"'both' (got {spec.target!r})."
|
||||
)
|
||||
local_path = self._resolve_lora_path(spec.path)
|
||||
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
|
||||
spec.name or spec.path, spec.weight, spec.target,
|
||||
local_path)
|
||||
log.info(" LoRA %s → strength=%.2f (%s)",
|
||||
spec.name or spec.path, spec.weight, local_path)
|
||||
resolved.append((spec, local_path))
|
||||
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
# Build the lora_configs list that LightX2V's lazy-load path
|
||||
# reads inside MultiDistillModelStruct.infer().
|
||||
lora_cfgs = []
|
||||
for spec, local_path in resolved:
|
||||
# LightX2V expects name "high_noise_model" / "low_noise_model"
|
||||
cfg_name = {
|
||||
"high_noise": "high_noise_model",
|
||||
"low_noise": "low_noise_model",
|
||||
}.get(spec.target)
|
||||
if cfg_name is None:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
lora_cfgs.append({
|
||||
"name": cfg_name,
|
||||
"path": local_path,
|
||||
"strength": spec.weight,
|
||||
})
|
||||
self._runner.set_config({
|
||||
"lora_configs": lora_cfgs,
|
||||
"lora_dynamic_apply": True,
|
||||
})
|
||||
else:
|
||||
# Models are loaded — use runtime hot-swap.
|
||||
high_path = high_strength = None
|
||||
low_path = low_strength = None
|
||||
for spec, local_path in resolved:
|
||||
if spec.target == "high_noise":
|
||||
high_path, high_strength = local_path, spec.weight
|
||||
elif spec.target == "low_noise":
|
||||
low_path, low_strength = local_path, spec.weight
|
||||
else:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
|
||||
kwargs: dict = {}
|
||||
if high_path is not None:
|
||||
kwargs["high_lora_path"] = high_path
|
||||
kwargs["high_lora_strength"] = high_strength
|
||||
if low_path is not None:
|
||||
kwargs["low_lora_path"] = low_path
|
||||
kwargs["low_lora_strength"] = low_strength
|
||||
ok = self._runner.switch_lora(**kwargs)
|
||||
if not ok:
|
||||
raise RuntimeError(
|
||||
"runner.switch_lora returned False. Check that your "
|
||||
"LightX2V build supports runtime LoRA updates for "
|
||||
f"{self.model_cls}.")
|
||||
|
||||
lora_cfgs = [
|
||||
{"path": local_path, "strength": spec.weight}
|
||||
for spec, local_path in resolved
|
||||
]
|
||||
self._runner.set_config({
|
||||
"lora_configs": lora_cfgs,
|
||||
"lora_dynamic_apply": True,
|
||||
})
|
||||
self._applied_loras = list(specs)
|
||||
|
||||
def unload_loras(self) -> None:
|
||||
"""Remove all currently applied LoRAs."""
|
||||
if not self._applied_loras:
|
||||
return
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
self._runner.set_config({
|
||||
"lora_configs": None,
|
||||
"lora_dynamic_apply": False,
|
||||
})
|
||||
# If models were materialised, drop them so the next inference
|
||||
# recreates them without LoRAs.
|
||||
model_struct = getattr(self._runner, "model", None)
|
||||
if model_struct is not None and hasattr(model_struct, "model"):
|
||||
for i in range(len(model_struct.model)):
|
||||
model_struct.model[i] = None
|
||||
else:
|
||||
self._runner.switch_lora("", 0.0)
|
||||
self._runner.set_config({
|
||||
"lora_configs": None,
|
||||
"lora_dynamic_apply": False,
|
||||
})
|
||||
self._applied_loras = []
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user