392 lines
15 KiB
Python
392 lines
15 KiB
Python
"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync.

Top-level orchestrator. The heavy 3rd-party model code is isolated in
``server/video_models/`` so each wrapper can be updated independently.

Two modes are supported (``config.video.mode``):

* ``library``    -- a handful of speaking base clips are pre-rendered once per
  avatar and reused round-robin; only the lip-sync runs per turn.
* ``reflective`` -- a fresh base clip is rendered every turn from a prompt
  derived from the assistant's reply text, then lip-synced.

This module is only imported by ``server/models.py`` when
``config.video.enabled`` is true. When disabled, the existing voice pipeline
is completely untouched.
"""

from __future__ import annotations

import logging
import threading
from dataclasses import dataclass, field
from typing import Literal

import numpy as np

log = logging.getLogger(__name__)

# Which Wan2.2 MoE sub-model(s) a LoRA adapter is applied to (see LoRASpec).
LoRATarget = Literal["high_noise", "low_noise", "both"]
|
|
|
|
|
|
@dataclass
class LoRASpec:
    """One LoRA adapter entry from ``config.video.loras``.

    Wan2.2 I2V is a Mixture-of-Experts model with separate high-noise and
    low-noise sub-models. Most LightX2V distill LoRAs come paired (one per
    sub-model) and must be applied to the correct target. Allow
    ``target="both"`` for LoRAs that should be applied to both sub-models
    (e.g. style LoRAs).
    """

    path: str
    weight: float = 1.0
    target: LoRATarget = "both"
    name: str | None = None  # optional display name used in log output


@dataclass
class VideoConfig:
    """Flattened view of the ``video:`` section of config.yml.

    The field defaults declared here are the single source of truth:
    ``from_dict`` falls back to them for any key missing from the YAML, so a
    default only ever needs to be changed in one place.
    """

    enabled: bool = False
    backend: str = "lightx2v"
    mode: str = "reflective"  # "library" | "reflective"
    resolution: int = 480
    fps: int = 16
    library_base_clip_count: int = 4
    library_base_clip_seconds: int = 6
    reflective_clip_seconds: int = 5
    reflective_prompt_template: str = (
        "webcam view of a person speaking, {reply_hint}, casual gestures, "
        "natural lighting, soft focus background"
    )
    reflective_prompt_reply_words: int = 18
    loras: list[LoRASpec] = field(default_factory=list)

    # Model paths — can be overridden via config.yml.video.models.
    # wan22_base_repo  : HF repo id (or local dir) providing T5/VAE/tokenizer.
    #                    The bf16 DIT shards in this repo are skipped — we
    #                    replace them with the fp8 files from wan22_fp8_repo.
    # wan22_fp8_repo   : HF repo id (or local dir) providing the two fp8 e4m3
    #                    4-step distilled DIT checkpoints (~15 GB each).
    # wan22_config_json: path to the LightX2V inference config template the
    #                    Wan22Pipeline will fill in with absolute ckpt paths.
    wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
    wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models"
    wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
    wan22_model_cls: str = "wan2.2_moe_distill"
    musetalk_model_path: str = "TMElyralab/MuseTalk"

    @classmethod
    def from_dict(cls, raw: dict | None) -> "VideoConfig":
        """Build a config from the raw ``video:`` mapping of config.yml.

        ``raw`` and its ``library`` / ``reflective`` / ``models`` sub-sections
        may each be missing or ``None``; absent keys fall back to the field
        defaults above. Scalar values are coerced with ``bool``/``int``/``str``
        so e.g. ``fps: "24"`` is accepted.

        Raises:
            ValueError / TypeError: if a numeric key holds a value that cannot
                be converted (propagated from ``int()`` / ``float()``).
        """
        raw = raw or {}
        library = raw.get("library") or {}
        reflective = raw.get("reflective") or {}
        models_raw = raw.get("models") or {}

        return cls(
            enabled=bool(raw.get("enabled", cls.enabled)),
            backend=str(raw.get("backend", cls.backend)),
            mode=cls._parse_mode(raw.get("mode", cls.mode)),
            resolution=int(raw.get("resolution", cls.resolution)),
            fps=int(raw.get("fps", cls.fps)),
            library_base_clip_count=int(
                library.get("base_clip_count", cls.library_base_clip_count)
            ),
            library_base_clip_seconds=int(
                library.get("base_clip_seconds", cls.library_base_clip_seconds)
            ),
            reflective_clip_seconds=int(
                reflective.get("clip_seconds", cls.reflective_clip_seconds)
            ),
            reflective_prompt_template=str(
                reflective.get(
                    "clip_prompt_template", cls.reflective_prompt_template
                )
            ),
            reflective_prompt_reply_words=int(
                reflective.get(
                    "prompt_reply_words", cls.reflective_prompt_reply_words
                )
            ),
            loras=cls._parse_loras(raw.get("loras") or []),
            wan22_base_repo=str(
                models_raw.get("wan22_base_repo", cls.wan22_base_repo)
            ),
            wan22_fp8_repo=str(
                models_raw.get("wan22_fp8_repo", cls.wan22_fp8_repo)
            ),
            wan22_config_json=str(
                models_raw.get("wan22_config_json", cls.wan22_config_json)
            ),
            wan22_model_cls=str(
                models_raw.get("wan22_model_cls", cls.wan22_model_cls)
            ),
            # ``musetalk_path`` is the historical key; the field name is also
            # accepted so it matches the naming of the other model keys.
            musetalk_model_path=str(
                models_raw.get(
                    "musetalk_path",
                    models_raw.get("musetalk_model_path", cls.musetalk_model_path),
                )
            ),
        )

    @classmethod
    def _parse_mode(cls, value) -> str:
        """Normalise ``video.mode``; unknown values fall back to the default.

        Downstream code only special-cases ``"library"`` and treats anything
        else as reflective, so a typo or different casing ("Library") would
        otherwise silently select reflective mode.
        """
        mode = str(value).strip().lower()
        if mode not in ("library", "reflective"):
            log.warning(
                "video.mode %r is invalid, defaulting to %r", value, cls.mode
            )
            return cls.mode
        return mode

    @staticmethod
    def _parse_loras(entries) -> list[LoRASpec]:
        """Convert raw ``loras`` list items into :class:`LoRASpec` objects.

        Empty items are ignored. Items that are not mappings with a ``path``
        key are skipped with a warning (previously a bare string was either
        skipped silently or crashed on ``.get``). An unknown ``target`` falls
        back to ``"both"`` with a warning.
        """
        specs: list[LoRASpec] = []
        for entry in entries:
            if not entry:
                continue  # e.g. an empty "-" item in the list
            if not isinstance(entry, dict) or "path" not in entry:
                log.warning("Skipping LoRA entry without a 'path': %r", entry)
                continue
            target = str(entry.get("target", "both")).lower()
            if target not in ("high_noise", "low_noise", "both"):
                log.warning(
                    "LoRA %s: invalid target %r, defaulting to 'both'",
                    entry.get("path"), target,
                )
                target = "both"
            specs.append(
                LoRASpec(
                    path=str(entry["path"]),
                    weight=float(entry.get("weight", 1.0)),
                    target=target,  # type: ignore[arg-type]
                    name=entry.get("name"),
                )
            )
        return specs
|
|
|
|
|
|
# Shared framing for the avatar prompts: camera/subject description up front,
# lighting/background description at the end.
_SPEAKING_PREFIX = "webcam view of a person speaking, "
_SCENE_SUFFIX = "natural lighting, soft focus background"

# Library-mode base-clip prompts. Only the gesture/expression phrase varies,
# so the pre-baked set feels less repetitive when replayed in rotation.
# Kept module-level so tests can import them.
LIBRARY_BASE_PROMPTS = [
    f"{_SPEAKING_PREFIX}{gesture}, {_SCENE_SUFFIX}"
    for gesture in (
        "subtle head nods, casual expression",
        "slight smile, gentle hand gesture",
        "looking thoughtful, small head tilt",
        "engaged and attentive, minor shoulder movement",
        "relaxed posture, blinking naturally",
    )
]

# Prompt for the looping "listening" clip shown while the avatar is silent.
IDLE_PROMPT = (
    "webcam view of a person listening quietly, mouth closed, "
    "subtle breathing, occasional blinks, calm expression, " + _SCENE_SUFFIX
)
|
|
|
|
|
|
class VideoEngine:
    """Top-level video generation orchestrator.

    Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's
    pre-rendered clips. Exposed to ``ConversationSession`` via
    ``ModelManager.video_engine``.

    Expected call order: ``load_models()`` -> ``set_avatar()`` -> one
    ``generate_speaking_clip()`` per assistant turn.
    """

    def __init__(self, cfg: VideoConfig):
        self.cfg = cfg
        # Held while swapping LoRAs and while (re)building avatar clips.
        self._lock = threading.Lock()

        # Avatar state
        self.avatar_path: str | None = None
        self.idle_clip_mp4: bytes | None = None
        # Pre-baked speaking base clips for library mode. Each entry is a
        # contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8.
        self.speaking_base_frames: list[np.ndarray] = []
        # Round-robin pointer for picking a library clip per turn
        self._library_cursor = 0

        # Model wrappers — instantiated lazily by ``load_models()`` so unit
        # tests can exercise VideoEngine without touching CUDA at all.
        self._wan22 = None  # server.video_models.wan22.Wan22Pipeline
        self._musetalk = None  # server.video_models.musetalk.MuseTalkEngine

        log.info(
            "VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).",
            cfg.mode, cfg.resolution, cfg.fps, len(cfg.loras),
        )

    # --- Model loading --------------------------------------------------

    def load_models(self) -> None:
        """Instantiate the underlying model wrappers.

        Separated from ``__init__`` so tests can mock ``_wan22``/``_musetalk``
        without triggering Wan2.2's ~12-16GB VRAM allocation. Any LoRAs listed
        in ``cfg.loras`` are applied right after the Wan2.2 pipeline is built.
        """
        # Imported lazily so importing this module never pulls in the heavy
        # model code.
        from server.video_models.wan22 import Wan22Pipeline
        from server.video_models.musetalk import MuseTalkEngine

        log.info(
            "Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...",
            self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo,
        )
        self._wan22 = Wan22Pipeline(
            base_repo=self.cfg.wan22_base_repo,
            fp8_repo=self.cfg.wan22_fp8_repo,
            config_json=self.cfg.wan22_config_json,
            model_cls=self.cfg.wan22_model_cls,
            resolution=self.cfg.resolution,
            fps=self.cfg.fps,
        )
        if self.cfg.loras:
            self._wan22.load_loras(self.cfg.loras)
        log.info("Wan2.2 pipeline ready.")

        log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
        self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
        log.info("MuseTalk engine ready.")

    # --- Readiness ------------------------------------------------------

    def is_ready(self) -> bool:
        """True when an avatar is set and a speaking clip can be produced.

        Requires both model wrappers to be loaded, an avatar path, and a
        rendered idle clip.
        """
        return (
            self._wan22 is not None
            and self._musetalk is not None
            and self.avatar_path is not None
            and self.idle_clip_mp4 is not None
        )

    # --- LoRA management ------------------------------------------------

    def load_loras(self, specs: list[LoRASpec]) -> None:
        """Apply a list of LoRA adapters to the Wan2.2 base.

        Replaces any previously applied LoRAs. Safe to call after init for
        hot-reload via ``POST /api/reload-loras``.

        Raises:
            RuntimeError: if called before ``load_models()``.

        Any exception raised by the pipeline while loading ``specs``
        propagates unchanged. In that case ``cfg.loras`` is left empty,
        because the previous adapters have already been unloaded.
        """
        if self._wan22 is None:
            raise RuntimeError("load_loras called before load_models()")
        with self._lock:
            self._wan22.unload_loras()
            # The old adapters are gone at this point; record that before
            # attempting the new load so a failure below does not leave
            # cfg.loras claiming adapters that are no longer applied.
            self.cfg.loras = []
            self._wan22.load_loras(specs)
            self.cfg.loras = list(specs)
        log.info("Applied %d LoRA(s): %s",
                 len(specs),
                 ", ".join(s.name or s.path for s in specs) or "<none>")

    # --- Avatar lifecycle ----------------------------------------------

    def set_avatar(self, image_path: str) -> None:
        """Register an avatar image and pre-generate cached clips.

        - Always: generate the idle loop (its length reuses
          ``library_base_clip_seconds`` in both modes).
        - Library mode: also pre-generate ``library.base_clip_count``
          speaking base clips, with deterministic seeds ``1..N``.
        - Reflective mode: idle loop only.

        Raises:
            RuntimeError: if called before ``load_models()``.
        """
        if self._wan22 is None:
            raise RuntimeError("set_avatar called before load_models()")

        with self._lock:
            log.info("Setting avatar: %s", image_path)
            self.avatar_path = image_path
            # Drop any previously cached clips so the new avatar's library
            # doesn't mix with the old (and is_ready() is False meanwhile).
            self.speaking_base_frames = []
            self.idle_clip_mp4 = None

            # Idle clip: short loop, neutral/listening prompt.
            log.info("Generating idle clip...")
            idle_frames = self._wan22.generate_i2v(
                image_path=image_path,
                prompt=IDLE_PROMPT,
                seconds=self.cfg.library_base_clip_seconds,
                seed=0,
            )
            from server.video_models.muxer import frames_to_mp4_loop
            self.idle_clip_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps)
            log.info("Idle clip ready (%d bytes).", len(self.idle_clip_mp4))

            # Library mode: pre-bake N speaking base clips, cycling through
            # the prompt list if N exceeds its length.
            if self.cfg.mode == "library":
                n = self.cfg.library_base_clip_count
                log.info("Pre-baking %d speaking base clip(s) for library mode.", n)
                for i in range(n):
                    prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)]
                    frames = self._wan22.generate_i2v(
                        image_path=image_path,
                        prompt=prompt,
                        seconds=self.cfg.library_base_clip_seconds,
                        seed=i + 1,
                    )
                    self.speaking_base_frames.append(frames)
                    log.info("  base clip %d/%d rendered", i + 1, n)

            self._library_cursor = 0

    def get_idle_clip(self) -> bytes | None:
        """Return the cached idle-loop MP4 bytes, or None if not rendered yet."""
        return self.idle_clip_mp4

    # --- Per-turn generation -------------------------------------------

    def generate_speaking_clip(
        self,
        audio_f32: np.ndarray,
        sample_rate: int,
        reply_text: str,
    ) -> bytes:
        """Produce a lip-synced MP4 for one assistant turn.

        Args:
            audio_f32: the turn's TTS audio samples (presumably mono float32,
                per the name — TODO confirm against the caller). The frame
                count is derived from ``len(audio_f32)``.
            sample_rate: sample rate of ``audio_f32`` in Hz.
            reply_text: assistant reply; only used in reflective mode to build
                the Wan2.2 prompt.

        Returns:
            The muxed MP4 (frames + audio) as bytes.

        Raises:
            RuntimeError: if the engine is not ready (see ``is_ready``).
        """
        if not self.is_ready():
            raise RuntimeError(
                "generate_speaking_clip: engine not ready "
                "(avatar set? models loaded?)"
            )
        # Narrowing for type checkers; guaranteed by is_ready() above.
        assert self._wan22 is not None
        assert self._musetalk is not None

        # 1. Source base frames.
        if self.cfg.mode == "library":
            base_frames = self._pick_library_frames(audio_f32, sample_rate)
        else:  # reflective
            prompt = self._derive_prompt(reply_text)
            log.info("Reflective prompt: %s", prompt[:120])
            base_frames = self._wan22.generate_i2v(
                image_path=self.avatar_path or "",
                prompt=prompt,
                seconds=self.cfg.reflective_clip_seconds,
                seed=None,  # random each turn
            )

        # 2. Lip-sync the base frames to the given audio.
        synced_frames = self._musetalk.lip_sync(
            frames=base_frames,
            audio=audio_f32,
            sample_rate=sample_rate,
            fps=self.cfg.fps,
        )

        # 3. Mux frames + audio into an MP4.
        from server.video_models.muxer import frames_and_audio_to_mp4
        return frames_and_audio_to_mp4(
            frames=synced_frames,
            audio=audio_f32,
            sample_rate=sample_rate,
            fps=self.cfg.fps,
        )

    def _pick_library_frames(
        self, audio_f32: np.ndarray, sample_rate: int
    ) -> np.ndarray:
        """Round-robin pick from the pre-baked library, clipped or looped
        to roughly the audio's duration so there's no long freeze frame.

        When the audio is longer than the clip, the clip is ping-pong looped
        (forward, then backward) so the wrap-around has no hard cut back to
        the first frame.

        Raises:
            RuntimeError: if the library is empty, or if the selected clip has
                no frames but the audio needs some.
        """
        if not self.speaking_base_frames:
            raise RuntimeError(
                "Library mode has no pre-baked base clips. "
                "Was set_avatar called with mode=library?"
            )
        frames = self.speaking_base_frames[
            self._library_cursor % len(self.speaking_base_frames)
        ]
        self._library_cursor += 1

        target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
        if target_frames <= 0:
            return frames
        if target_frames <= len(frames):
            return frames[:target_frames]
        if len(frames) == 0:
            raise RuntimeError("Library base clip has no frames to loop.")

        # Ping-pong cycle: forward pass, then the reverse pass without its two
        # endpoints so no frame is shown twice in a row at either turn-around.
        # For 1- or 2-frame clips the reverse part is empty (plain repeat).
        cycle = np.concatenate([frames, frames[-2:0:-1]], axis=0)
        indices = np.arange(target_frames) % len(cycle)
        return cycle[indices]

    def _derive_prompt(self, reply_text: str) -> str:
        """Template-based prompt builder for reflective mode.

        Takes up to ``prompt_reply_words`` words from the start of the reply
        and interpolates them into the configured template via
        ``str.format(reply_hint=...)``. Cheap, deterministic, no extra LLM
        call. Falls back to ``"calm and friendly"`` for an empty reply.
        """
        words = (reply_text or "").split()
        hint = " ".join(words[: self.cfg.reflective_prompt_reply_words]).strip()
        if not hint:
            hint = "calm and friendly"
        return self.cfg.reflective_prompt_template.format(reply_hint=hint)
|