"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync.

Top-level orchestrator. The heavy third-party model code is isolated in
``server/video_models/`` so each wrapper can be updated independently.

This module is only imported by ``server/models.py`` when
``config.video.enabled`` is true. When disabled, the existing voice pipeline
is completely untouched.
"""

from __future__ import annotations

import logging
import threading
from dataclasses import dataclass, field
from typing import Literal

import numpy as np

log = logging.getLogger(__name__)

LoRATarget = Literal["both"]

# Single source of truth for the reflective-mode prompt template default.
# ``{reply_hint}`` is filled in by ``VideoEngine._derive_prompt``.
_DEFAULT_REFLECTIVE_PROMPT_TEMPLATE = (
    "webcam view of a person speaking, {reply_hint}, casual gestures, "
    "natural lighting, soft focus background"
)


@dataclass
class LoRASpec:
    """One LoRA adapter entry from ``config.video.loras``.

    The dense Wan2.2-TI2V-5B DIT has a single set of weights (no MoE
    experts), so ``target`` is always ``"both"``. The field is kept for
    forward compatibility and config-file compatibility with older MoE
    configs — legacy ``"high_noise"`` / ``"low_noise"`` values are coerced
    to ``"both"`` in ``VideoConfig.from_dict``.
    """

    path: str
    weight: float = 1.0
    target: LoRATarget = "both"
    name: str | None = None


def _parse_lora_entries(entries: object) -> list[LoRASpec]:
    """Convert the raw ``video.loras`` value into ``LoRASpec`` objects.

    Each usable entry must be a mapping with a ``path`` key; ``weight``
    (default 1.0), ``target`` and ``name`` are optional. Anything else is
    skipped with a warning instead of being dropped silently — in particular
    a bare string path would otherwise crash on ``.get`` whenever the string
    happens to contain the substring "path".

    Args:
        entries: The raw ``loras`` value from config.yml (expected: a list).

    Returns:
        The parsed specs, in config order.
    """
    if not entries:
        return []
    if not isinstance(entries, (list, tuple)):
        log.warning(
            "video.loras should be a list of mappings, got %s; ignoring.",
            type(entries).__name__,
        )
        return []

    specs: list[LoRASpec] = []
    for entry in entries:
        if not isinstance(entry, dict) or "path" not in entry:
            log.warning(
                "Skipping LoRA entry %r: expected a mapping with a 'path' key.",
                entry,
            )
            continue
        target = str(entry.get("target", "both")).lower()
        if target != "both":
            log.warning(
                "LoRA %s: target %r is MoE-era; coercing to 'both' "
                "(dense 5B has a single DIT).",
                entry.get("path"),
                target,
            )
            target = "both"
        specs.append(
            LoRASpec(
                path=str(entry["path"]),
                weight=float(entry.get("weight", 1.0)),
                target=target,  # type: ignore[arg-type]
                name=entry.get("name"),
            )
        )
    return specs


@dataclass
class VideoConfig:
    """Flattened view of the ``video:`` section of config.yml.

    Build instances with ``from_dict``; the field defaults declared here are
    also the fallbacks ``from_dict`` uses for keys missing from the config.
    """

    enabled: bool = False
    backend: str = "lightx2v"
    mode: str = "reflective"  # "library" | "reflective"
    resolution: int = 480
    fps: int = 16
    library_base_clip_count: int = 4
    library_base_clip_seconds: int = 6
    reflective_clip_seconds: int = 5
    reflective_prompt_template: str = _DEFAULT_REFLECTIVE_PROMPT_TEMPLATE
    reflective_prompt_reply_words: int = 18
    loras: list[LoRASpec] = field(default_factory=list)

    # Model paths — can be overridden via config.yml ``video.models``.
    #   wan22_base_repo        : HF repo id (or local dir) providing T5/VAE/
    #                            tokenizer. The bf16 DIT shards in this repo are
    #                            skipped — they are replaced by a quantised GGUF
    #                            from wan22_dit_repo.
    #   wan22_dit_repo         : HF repo id (or local dir) providing the single
    #                            dense GGUF DIT checkpoint (5B Turbo).
    #   wan22_dit_quant_scheme : GGUF quant level, e.g. "gguf-Q8_0" (default)
    #                            or "gguf-Q4_K_M" for lower VRAM.
    #   wan22_t5_quantized     : forwarded to Wan22Pipeline(t5_quantized=...).
    #   wan22_config_json      : path to the LightX2V inference config template
    #                            the Wan22Pipeline fills in with absolute ckpt
    #                            paths.
    #   wan22_model_cls        : forwarded to Wan22Pipeline(model_cls=...).
    wan22_base_repo: str = "Wan-AI/Wan2.2-TI2V-5B"
    wan22_dit_repo: str = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
    wan22_dit_quant_scheme: str = "gguf-Q8_0"
    wan22_t5_quantized: bool = True
    wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
    wan22_model_cls: str = "wan2.2"

    musetalk_enabled: bool = True
    # Read from ``video.models.musetalk_path`` in config.yml.
    musetalk_model_path: str = "TMElyralab/MuseTalk"

    @classmethod
    def from_dict(cls, raw: dict | None) -> "VideoConfig":
        """Build a config from the raw ``video:`` mapping of config.yml.

        Recognised layout (every key optional)::

            enabled, backend, mode, resolution, fps
            library:    {base_clip_count, base_clip_seconds}
            reflective: {clip_seconds, clip_prompt_template, prompt_reply_words}
            loras:      [{path, weight?, target?, name?}, ...]
            models:     {wan22_base_repo, wan22_dit_repo, wan22_dit_quant_scheme,
                         wan22_t5_quantized, wan22_config_json, wan22_model_cls,
                         musetalk_path}
            musetalk:   {enabled}        # or the flat key ``musetalk_enabled``

        Args:
            raw: The ``video:`` mapping; ``None`` or empty yields all defaults.

        Returns:
            A populated ``VideoConfig``. Missing keys take the class defaults.
        """
        raw = raw or {}
        # Defaults come from the dataclass itself so they cannot drift apart
        # from the field declarations above.
        defaults = cls()
        library = raw.get("library") or {}
        reflective = raw.get("reflective") or {}
        models_raw = raw.get("models") or {}

        # Nested ``musetalk: {enabled: ...}`` wins; otherwise accept the flat
        # ``musetalk_enabled`` key.
        musetalk_section = raw.get("musetalk")
        if isinstance(musetalk_section, dict):
            musetalk_enabled = bool(
                musetalk_section.get("enabled", defaults.musetalk_enabled)
            )
        else:
            musetalk_enabled = bool(
                raw.get("musetalk_enabled", defaults.musetalk_enabled)
            )

        return cls(
            enabled=bool(raw.get("enabled", defaults.enabled)),
            backend=str(raw.get("backend", defaults.backend)),
            mode=str(raw.get("mode", defaults.mode)),
            resolution=int(raw.get("resolution", defaults.resolution)),
            fps=int(raw.get("fps", defaults.fps)),
            library_base_clip_count=int(
                library.get("base_clip_count", defaults.library_base_clip_count)
            ),
            library_base_clip_seconds=int(
                library.get(
                    "base_clip_seconds", defaults.library_base_clip_seconds
                )
            ),
            reflective_clip_seconds=int(
                reflective.get("clip_seconds", defaults.reflective_clip_seconds)
            ),
            reflective_prompt_template=str(
                reflective.get(
                    "clip_prompt_template", defaults.reflective_prompt_template
                )
            ),
            reflective_prompt_reply_words=int(
                reflective.get(
                    "prompt_reply_words", defaults.reflective_prompt_reply_words
                )
            ),
            loras=_parse_lora_entries(raw.get("loras")),
            wan22_base_repo=str(
                models_raw.get("wan22_base_repo", defaults.wan22_base_repo)
            ),
            wan22_dit_repo=str(
                models_raw.get("wan22_dit_repo", defaults.wan22_dit_repo)
            ),
            wan22_dit_quant_scheme=str(
                models_raw.get(
                    "wan22_dit_quant_scheme", defaults.wan22_dit_quant_scheme
                )
            ),
            wan22_t5_quantized=bool(
                models_raw.get("wan22_t5_quantized", defaults.wan22_t5_quantized)
            ),
            wan22_config_json=str(
                models_raw.get("wan22_config_json", defaults.wan22_config_json)
            ),
            wan22_model_cls=str(
                models_raw.get("wan22_model_cls", defaults.wan22_model_cls)
            ),
            musetalk_enabled=musetalk_enabled,
            musetalk_model_path=str(
                models_raw.get("musetalk_path", defaults.musetalk_model_path)
            ),
        )


# Library-mode base-clip prompts. Varied gestures so the pre-baked set feels
# less repetitive when replayed. Kept module-level so tests can import them.
LIBRARY_BASE_PROMPTS = [
    "webcam view of a person speaking, subtle head nods, casual expression, "
    "natural lighting, soft focus background",
    "webcam view of a person speaking, slight smile, gentle hand gesture, "
    "natural lighting, soft focus background",
    "webcam view of a person speaking, looking thoughtful, small head tilt, "
    "natural lighting, soft focus background",
    "webcam view of a person speaking, engaged and attentive, minor shoulder "
    "movement, natural lighting, soft focus background",
    "webcam view of a person speaking, relaxed posture, blinking naturally, "
    "natural lighting, soft focus background",
]

IDLE_PROMPT = (
    "webcam view of a person listening quietly, mouth closed, subtle "
    "breathing, occasional blinks, calm expression, natural lighting, "
    "soft focus background"
)


class VideoEngine:
    """Top-level video generation orchestrator.

    Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's
    pre-rendered clips. Exposed to ``ConversationSession`` via
    ``ModelManager.video_engine``.

    ``set_avatar`` and ``load_loras`` serialise on ``self._lock``;
    ``generate_speaking_clip`` does not take the lock.
    """

    def __init__(self, cfg: VideoConfig):
        self.cfg = cfg
        self._lock = threading.Lock()

        # Avatar state
        self.avatar_path: str | None = None
        self.idle_clip_mp4: bytes | None = None
        # Pre-baked speaking base clips for library mode. Each entry is a
        # contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8.
        self.speaking_base_frames: list[np.ndarray] = []
        # Round-robin pointer for picking a library clip per turn
        self._library_cursor = 0

        # Model wrappers — instantiated lazily by ``load_models()`` so unit
        # tests can exercise VideoEngine without touching CUDA at all.
        self._wan22 = None  # server.video_models.wan22.Wan22Pipeline
        self._musetalk = None  # server.video_models.musetalk.MuseTalkEngine

        log.info(
            "VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).",
            cfg.mode,
            cfg.resolution,
            cfg.fps,
            len(cfg.loras),
        )

    # --- Model loading --------------------------------------------------

    def load_models(self) -> None:
        """Instantiate the underlying model wrappers.

        Separated from ``__init__`` so tests can mock ``_wan22``/``_musetalk``
        without triggering Wan2.2's ~12-16GB VRAM allocation. Configured LoRAs
        are applied right after the Wan2.2 pipeline is built; MuseTalk is only
        loaded when ``cfg.musetalk_enabled`` is true.
        """
        from server.video_models.wan22 import Wan22Pipeline
        from server.video_models.musetalk import MuseTalkEngine

        log.info(
            "Loading Wan2.2 pipeline (base=%s, dit=%s, quant=%s)...",
            self.cfg.wan22_base_repo,
            self.cfg.wan22_dit_repo,
            self.cfg.wan22_dit_quant_scheme,
        )
        self._wan22 = Wan22Pipeline(
            base_repo=self.cfg.wan22_base_repo,
            dit_repo=self.cfg.wan22_dit_repo,
            config_json=self.cfg.wan22_config_json,
            model_cls=self.cfg.wan22_model_cls,
            resolution=self.cfg.resolution,
            fps=self.cfg.fps,
            dit_quant_scheme=self.cfg.wan22_dit_quant_scheme,
            t5_quantized=self.cfg.wan22_t5_quantized,
        )
        if self.cfg.loras:
            self._wan22.load_loras(self.cfg.loras)
        log.info("Wan2.2 pipeline ready.")

        if self.cfg.musetalk_enabled:
            log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
            self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
            log.info("MuseTalk engine ready.")
        else:
            log.info("MuseTalk disabled via config — skipping lip-sync pass.")
            self._musetalk = None

    # --- Readiness ------------------------------------------------------

    def is_ready(self) -> bool:
        """True when an avatar is set and a speaking clip can be produced.

        Requires the Wan2.2 pipeline, MuseTalk (only if enabled), a registered
        avatar path and a rendered idle clip.
        """
        musetalk_ok = (not self.cfg.musetalk_enabled) or self._musetalk is not None
        return (
            self._wan22 is not None
            and musetalk_ok
            and self.avatar_path is not None
            and self.idle_clip_mp4 is not None
        )

    # --- LoRA management ------------------------------------------------

    def load_loras(self, specs: list[LoRASpec]) -> None:
        """Apply a list of LoRA adapters to the Wan2.2 base.

        Replaces any previously applied LoRAs and records ``specs`` in
        ``cfg.loras``. Safe to call after init for hot-reload via
        ``POST /api/reload-loras``.

        Raises:
            RuntimeError: if called before ``load_models()``.
        """
        if self._wan22 is None:
            raise RuntimeError("load_loras called before load_models()")
        with self._lock:
            self._wan22.unload_loras()
            self._wan22.load_loras(specs)
            self.cfg.loras = list(specs)
        log.info(
            "Applied %d LoRA(s): %s",
            len(specs),
            # Without a non-empty fallback an empty list logs a dangling ": ".
            ", ".join(s.name or s.path for s in specs) or "none",
        )

    # --- Avatar lifecycle ----------------------------------------------

    def set_avatar(self, image_path: str) -> None:
        """Register an avatar image and pre-generate cached clips.

        - Always: generate the idle loop.
        - Library mode: also pre-generate ``library.base_clip_count``
          speaking base clips.
        - Reflective mode: idle loop only.

        Previously cached clips are discarded first, so ``is_ready()`` stays
        false until the new idle clip has been rendered.

        Raises:
            RuntimeError: if called before ``load_models()``.
        """
        if self._wan22 is None:
            raise RuntimeError("set_avatar called before load_models()")

        with self._lock:
            log.info("Setting avatar: %s", image_path)
            self.avatar_path = image_path
            # Drop any previously cached clips so the new avatar's library
            # doesn't mix with the old.
            self.speaking_base_frames = []
            self.idle_clip_mp4 = None

            # Idle clip: short loop, neutral/listening prompt. Its length
            # reuses ``library_base_clip_seconds`` in every mode.
            log.info("Generating idle clip...")
            idle_frames = self._wan22.generate_i2v(
                image_path=image_path,
                prompt=IDLE_PROMPT,
                seconds=self.cfg.library_base_clip_seconds,
                seed=0,
            )
            from server.video_models.muxer import frames_to_mp4_loop

            self.idle_clip_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps)
            log.info("Idle clip ready (%d bytes).", len(self.idle_clip_mp4))

            # Library mode: pre-bake N speaking base clips, cycling through
            # LIBRARY_BASE_PROMPTS with deterministic seeds 1..N.
            if self.cfg.mode == "library":
                n = self.cfg.library_base_clip_count
                log.info("Pre-baking %d speaking base clip(s) for library mode.", n)
                for i in range(n):
                    prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)]
                    frames = self._wan22.generate_i2v(
                        image_path=image_path,
                        prompt=prompt,
                        seconds=self.cfg.library_base_clip_seconds,
                        seed=i + 1,
                    )
                    self.speaking_base_frames.append(frames)
                    log.info("  base clip %d/%d rendered", i + 1, n)
                self._library_cursor = 0

    def get_idle_clip(self) -> bytes | None:
        """Return the cached idle-loop MP4, or ``None`` if none is rendered."""
        return self.idle_clip_mp4

    # --- Per-turn generation -------------------------------------------

    def generate_speaking_clip(
        self,
        audio_f32: np.ndarray,
        sample_rate: int,
        reply_text: str,
    ) -> bytes:
        """Produce a lip-synced MP4 for one assistant turn.

        Args:
            audio_f32: The turn's speech audio. Its length in samples sets the
                library-mode clip length (presumably mono float32 — TODO
                confirm against the TTS caller).
            sample_rate: Sample rate of ``audio_f32`` in Hz.
            reply_text: Assistant reply text; only used to build the
                reflective-mode prompt.

        Returns:
            MP4 bytes from ``frames_and_audio_to_mp4``.

        Raises:
            RuntimeError: if ``is_ready()`` is false.
        """
        if not self.is_ready():
            raise RuntimeError(
                "generate_speaking_clip: engine not ready "
                "(avatar set? models loaded?)"
            )
        assert self._wan22 is not None  # narrowing only; guaranteed by is_ready()

        # 1. Source base frames.
        if self.cfg.mode == "library":
            base_frames = self._pick_library_frames(audio_f32, sample_rate)
        else:  # reflective
            prompt = self._derive_prompt(reply_text)
            log.info("Reflective prompt: %s", prompt[:120])
            base_frames = self._wan22.generate_i2v(
                image_path=self.avatar_path or "",
                prompt=prompt,
                seconds=self.cfg.reflective_clip_seconds,
                seed=None,  # random each turn
            )

        # 2. Lip-sync the base frames to the given audio (if enabled).
        if self._musetalk is not None:
            synced_frames = self._musetalk.lip_sync(
                frames=base_frames,
                audio=audio_f32,
                sample_rate=sample_rate,
                fps=self.cfg.fps,
            )
        else:
            synced_frames = base_frames

        # 3. Mux frames + audio into an MP4.
        from server.video_models.muxer import frames_and_audio_to_mp4

        return frames_and_audio_to_mp4(
            frames=synced_frames,
            audio=audio_f32,
            sample_rate=sample_rate,
            fps=self.cfg.fps,
        )

    def _pick_library_frames(
        self, audio_f32: np.ndarray, sample_rate: int
    ) -> np.ndarray:
        """Round-robin pick from the pre-baked library, sized to the audio.

        The chosen clip is truncated, or cyclically repeated, to roughly
        ``len(audio_f32) / sample_rate * fps`` frames so there's no long
        freeze frame. Empty audio, or an empty clip, returns the clip as-is.

        Raises:
            ValueError: if ``sample_rate`` is not positive.
            RuntimeError: if the library has no pre-baked clips.
        """
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")
        if not self.speaking_base_frames:
            raise RuntimeError(
                "Library mode has no pre-baked base clips. "
                "Was set_avatar called with mode=library?"
            )
        clips = self.speaking_base_frames
        frames = clips[self._library_cursor % len(clips)]
        self._library_cursor += 1

        n_frames = len(frames)
        target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
        # An empty clip cannot be looped (would divide by zero); hand it back.
        if target_frames <= 0 or n_frames == 0:
            return frames
        if target_frames <= n_frames:
            return frames[:target_frames]

        # Plain cyclic loop (0, 1, ..., n-1, 0, 1, ...) to cover longer audio.
        # There is a hard cut at each seam — no mirroring/cross-fade is done.
        indices = np.arange(target_frames) % n_frames
        return frames[indices]

    def _derive_prompt(self, reply_text: str) -> str:
        """Template-based prompt builder for reflective mode.

        Takes up to ``reflective_prompt_reply_words`` words from the start of
        the reply and interpolates them into the configured template as
        ``{reply_hint}``; an empty reply falls back to "calm and friendly".
        Cheap, deterministic, no extra LLM call.
        """
        words = (reply_text or "").split()
        hint = " ".join(words[: self.cfg.reflective_prompt_reply_words]).strip()
        if not hint:
            hint = "calm and friendly"
        return self.cfg.reflective_prompt_template.format(reply_hint=hint)