first stab at adding video
This commit is contained in:
+391
@@ -0,0 +1,391 @@
|
||||
"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync.
|
||||
|
||||
Top-level orchestrator. The heavy 3rd-party model code is isolated in
|
||||
``server/video_models/`` so each wrapper can be updated independently.
|
||||
|
||||
This module is only imported by ``server/models.py`` when
|
||||
``config.video.enabled`` is true. When disabled, the existing voice pipeline
|
||||
is completely untouched.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LoRATarget = Literal["high_noise", "low_noise", "both"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRASpec:
|
||||
"""One LoRA adapter entry from ``config.video.loras``.
|
||||
|
||||
Wan2.2 I2V is a Mixture-of-Experts model with separate high-noise and
|
||||
low-noise sub-models. Most LightX2V distill LoRAs come paired (one per
|
||||
sub-model) and must be applied to the correct target. Allow
|
||||
``target="both"`` for LoRAs that should be applied to both sub-models
|
||||
(e.g. style LoRAs).
|
||||
"""
|
||||
|
||||
path: str
|
||||
weight: float = 1.0
|
||||
target: LoRATarget = "both"
|
||||
name: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoConfig:
|
||||
"""Flattened view of the ``video:`` section of config.yml."""
|
||||
|
||||
enabled: bool = False
|
||||
backend: str = "lightx2v"
|
||||
mode: str = "reflective" # "library" | "reflective"
|
||||
resolution: int = 480
|
||||
fps: int = 16
|
||||
library_base_clip_count: int = 4
|
||||
library_base_clip_seconds: int = 6
|
||||
reflective_clip_seconds: int = 5
|
||||
reflective_prompt_template: str = (
|
||||
"webcam view of a person speaking, {reply_hint}, casual gestures, "
|
||||
"natural lighting, soft focus background"
|
||||
)
|
||||
reflective_prompt_reply_words: int = 18
|
||||
loras: list[LoRASpec] = field(default_factory=list)
|
||||
|
||||
# Model paths — can be overridden via config.yml.video.models.
|
||||
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
|
||||
# The bf16 DIT shards in this repo are skipped — we
|
||||
# replace them with the fp8 files from wan22_fp8_repo.
|
||||
# wan22_fp8_repo : HF repo id (or local dir) providing the two fp8 e4m3
|
||||
# 4-step distilled DIT checkpoints (~15 GB each).
|
||||
# wan22_config_json: path to the LightX2V inference config template the
|
||||
# Wan22Pipeline will fill in with absolute ckpt paths.
|
||||
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
|
||||
wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models"
|
||||
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
||||
wan22_model_cls: str = "wan2.2_moe_distill"
|
||||
musetalk_model_path: str = "TMElyralab/MuseTalk"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, raw: dict) -> "VideoConfig":
|
||||
raw = raw or {}
|
||||
library = raw.get("library", {}) or {}
|
||||
reflective = raw.get("reflective", {}) or {}
|
||||
models_raw = raw.get("models", {}) or {}
|
||||
loras_raw = raw.get("loras") or []
|
||||
|
||||
default_template = (
|
||||
"webcam view of a person speaking, {reply_hint}, casual gestures, "
|
||||
"natural lighting, soft focus background"
|
||||
)
|
||||
|
||||
loras: list[LoRASpec] = []
|
||||
for entry in loras_raw:
|
||||
if not entry or "path" not in entry:
|
||||
continue
|
||||
target = str(entry.get("target", "both")).lower()
|
||||
if target not in ("high_noise", "low_noise", "both"):
|
||||
log.warning(
|
||||
"LoRA %s: invalid target %r, defaulting to 'both'",
|
||||
entry.get("path"), target,
|
||||
)
|
||||
target = "both"
|
||||
loras.append(
|
||||
LoRASpec(
|
||||
path=str(entry["path"]),
|
||||
weight=float(entry.get("weight", 1.0)),
|
||||
target=target, # type: ignore[arg-type]
|
||||
name=entry.get("name"),
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
enabled=bool(raw.get("enabled", False)),
|
||||
backend=str(raw.get("backend", "lightx2v")),
|
||||
mode=str(raw.get("mode", "reflective")),
|
||||
resolution=int(raw.get("resolution", 480)),
|
||||
fps=int(raw.get("fps", 16)),
|
||||
library_base_clip_count=int(library.get("base_clip_count", 4)),
|
||||
library_base_clip_seconds=int(library.get("base_clip_seconds", 6)),
|
||||
reflective_clip_seconds=int(reflective.get("clip_seconds", 5)),
|
||||
reflective_prompt_template=str(
|
||||
reflective.get("clip_prompt_template", default_template)
|
||||
),
|
||||
reflective_prompt_reply_words=int(reflective.get("prompt_reply_words", 18)),
|
||||
loras=loras,
|
||||
wan22_base_repo=str(
|
||||
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
|
||||
),
|
||||
wan22_fp8_repo=str(
|
||||
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models")
|
||||
),
|
||||
wan22_config_json=str(
|
||||
models_raw.get(
|
||||
"wan22_config_json",
|
||||
"/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
|
||||
)
|
||||
),
|
||||
wan22_model_cls=str(
|
||||
models_raw.get("wan22_model_cls", "wan2.2_moe_distill")
|
||||
),
|
||||
musetalk_model_path=str(
|
||||
models_raw.get("musetalk_path", "TMElyralab/MuseTalk")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Library-mode base-clip prompts. Varied gestures so the pre-baked set feels
|
||||
# less repetitive when replayed. Kept module-level so tests can import them.
|
||||
LIBRARY_BASE_PROMPTS = [
|
||||
"webcam view of a person speaking, subtle head nods, casual expression, "
|
||||
"natural lighting, soft focus background",
|
||||
"webcam view of a person speaking, slight smile, gentle hand gesture, "
|
||||
"natural lighting, soft focus background",
|
||||
"webcam view of a person speaking, looking thoughtful, small head tilt, "
|
||||
"natural lighting, soft focus background",
|
||||
"webcam view of a person speaking, engaged and attentive, minor shoulder "
|
||||
"movement, natural lighting, soft focus background",
|
||||
"webcam view of a person speaking, relaxed posture, blinking naturally, "
|
||||
"natural lighting, soft focus background",
|
||||
]
|
||||
|
||||
IDLE_PROMPT = (
|
||||
"webcam view of a person listening quietly, mouth closed, subtle "
|
||||
"breathing, occasional blinks, calm expression, natural lighting, "
|
||||
"soft focus background"
|
||||
)
|
||||
|
||||
|
||||
class VideoEngine:
|
||||
"""Top-level video generation orchestrator.
|
||||
|
||||
Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's
|
||||
pre-rendered clips. Exposed to ``ConversationSession`` via
|
||||
``ModelManager.video_engine``.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: VideoConfig):
|
||||
self.cfg = cfg
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Avatar state
|
||||
self.avatar_path: str | None = None
|
||||
self.idle_clip_mp4: bytes | None = None
|
||||
# Pre-baked speaking base clips for library mode. Each entry is a
|
||||
# contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8.
|
||||
self.speaking_base_frames: list[np.ndarray] = []
|
||||
# Round-robin pointer for picking a library clip per turn
|
||||
self._library_cursor = 0
|
||||
|
||||
# Model wrappers — instantiated lazily by ``load_models()`` so unit
|
||||
# tests can exercise VideoEngine without touching CUDA at all.
|
||||
self._wan22 = None # server.video_models.wan22.Wan22Pipeline
|
||||
self._musetalk = None # server.video_models.musetalk.MuseTalkEngine
|
||||
|
||||
log.info(
|
||||
"VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).",
|
||||
cfg.mode, cfg.resolution, cfg.fps, len(cfg.loras),
|
||||
)
|
||||
|
||||
# --- Model loading --------------------------------------------------
|
||||
|
||||
def load_models(self) -> None:
|
||||
"""Instantiate the underlying model wrappers.
|
||||
|
||||
Separated from ``__init__`` so tests can mock ``_wan22``/``_musetalk``
|
||||
without triggering Wan2.2's ~12-16GB VRAM allocation.
|
||||
"""
|
||||
from server.video_models.wan22 import Wan22Pipeline
|
||||
from server.video_models.musetalk import MuseTalkEngine
|
||||
|
||||
log.info(
|
||||
"Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...",
|
||||
self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo,
|
||||
)
|
||||
self._wan22 = Wan22Pipeline(
|
||||
base_repo=self.cfg.wan22_base_repo,
|
||||
fp8_repo=self.cfg.wan22_fp8_repo,
|
||||
config_json=self.cfg.wan22_config_json,
|
||||
model_cls=self.cfg.wan22_model_cls,
|
||||
resolution=self.cfg.resolution,
|
||||
fps=self.cfg.fps,
|
||||
)
|
||||
if self.cfg.loras:
|
||||
self._wan22.load_loras(self.cfg.loras)
|
||||
log.info("Wan2.2 pipeline ready.")
|
||||
|
||||
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
|
||||
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
|
||||
log.info("MuseTalk engine ready.")
|
||||
|
||||
# --- Readiness ------------------------------------------------------
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""True when an avatar is set and a speaking clip can be produced."""
|
||||
return (
|
||||
self._wan22 is not None
|
||||
and self._musetalk is not None
|
||||
and self.avatar_path is not None
|
||||
and self.idle_clip_mp4 is not None
|
||||
)
|
||||
|
||||
# --- LoRA management ------------------------------------------------
|
||||
|
||||
def load_loras(self, specs: list[LoRASpec]) -> None:
|
||||
"""Apply a list of LoRA adapters to the Wan2.2 base.
|
||||
|
||||
Replaces any previously applied LoRAs. Safe to call after init for
|
||||
hot-reload via ``POST /api/reload-loras``.
|
||||
"""
|
||||
if self._wan22 is None:
|
||||
raise RuntimeError("load_loras called before load_models()")
|
||||
with self._lock:
|
||||
self._wan22.unload_loras()
|
||||
self._wan22.load_loras(specs)
|
||||
self.cfg.loras = list(specs)
|
||||
log.info("Applied %d LoRA(s): %s",
|
||||
len(specs),
|
||||
", ".join(s.name or s.path for s in specs) or "<none>")
|
||||
|
||||
# --- Avatar lifecycle ----------------------------------------------
|
||||
|
||||
def set_avatar(self, image_path: str) -> None:
|
||||
"""Register an avatar image and pre-generate cached clips.
|
||||
|
||||
- Always: generate the idle loop.
|
||||
- Library mode: also pre-generate ``library.base_clip_count``
|
||||
speaking base clips.
|
||||
- Reflective mode: idle loop only.
|
||||
"""
|
||||
if self._wan22 is None:
|
||||
raise RuntimeError("set_avatar called before load_models()")
|
||||
|
||||
with self._lock:
|
||||
log.info("Setting avatar: %s", image_path)
|
||||
self.avatar_path = image_path
|
||||
# Drop any previously cached clips so the new avatar's library
|
||||
# doesn't mix with the old.
|
||||
self.speaking_base_frames = []
|
||||
self.idle_clip_mp4 = None
|
||||
|
||||
# Idle clip: short loop, neutral/listening prompt.
|
||||
log.info("Generating idle clip...")
|
||||
idle_frames = self._wan22.generate_i2v(
|
||||
image_path=image_path,
|
||||
prompt=IDLE_PROMPT,
|
||||
seconds=self.cfg.library_base_clip_seconds,
|
||||
seed=0,
|
||||
)
|
||||
from server.video_models.muxer import frames_to_mp4_loop
|
||||
self.idle_clip_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps)
|
||||
log.info("Idle clip ready (%d bytes).", len(self.idle_clip_mp4))
|
||||
|
||||
# Library mode: pre-bake N speaking base clips.
|
||||
if self.cfg.mode == "library":
|
||||
n = self.cfg.library_base_clip_count
|
||||
log.info("Pre-baking %d speaking base clip(s) for library mode.", n)
|
||||
for i in range(n):
|
||||
prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)]
|
||||
frames = self._wan22.generate_i2v(
|
||||
image_path=image_path,
|
||||
prompt=prompt,
|
||||
seconds=self.cfg.library_base_clip_seconds,
|
||||
seed=i + 1,
|
||||
)
|
||||
self.speaking_base_frames.append(frames)
|
||||
log.info(" base clip %d/%d rendered", i + 1, n)
|
||||
|
||||
self._library_cursor = 0
|
||||
|
||||
def get_idle_clip(self) -> bytes | None:
|
||||
return self.idle_clip_mp4
|
||||
|
||||
# --- Per-turn generation -------------------------------------------
|
||||
|
||||
def generate_speaking_clip(
|
||||
self,
|
||||
audio_f32: np.ndarray,
|
||||
sample_rate: int,
|
||||
reply_text: str,
|
||||
) -> bytes:
|
||||
"""Produce a lip-synced MP4 for one assistant turn."""
|
||||
if not self.is_ready():
|
||||
raise RuntimeError(
|
||||
"generate_speaking_clip: engine not ready "
|
||||
"(avatar set? models loaded?)"
|
||||
)
|
||||
assert self._wan22 is not None
|
||||
assert self._musetalk is not None
|
||||
|
||||
# 1. Source base frames.
|
||||
if self.cfg.mode == "library":
|
||||
base_frames = self._pick_library_frames(audio_f32, sample_rate)
|
||||
else: # reflective
|
||||
prompt = self._derive_prompt(reply_text)
|
||||
log.info("Reflective prompt: %s", prompt[:120])
|
||||
base_frames = self._wan22.generate_i2v(
|
||||
image_path=self.avatar_path or "",
|
||||
prompt=prompt,
|
||||
seconds=self.cfg.reflective_clip_seconds,
|
||||
seed=None, # random each turn
|
||||
)
|
||||
|
||||
# 2. Lip-sync the base frames to the given audio.
|
||||
synced_frames = self._musetalk.lip_sync(
|
||||
frames=base_frames,
|
||||
audio=audio_f32,
|
||||
sample_rate=sample_rate,
|
||||
fps=self.cfg.fps,
|
||||
)
|
||||
|
||||
# 3. Mux frames + audio into an MP4.
|
||||
from server.video_models.muxer import frames_and_audio_to_mp4
|
||||
return frames_and_audio_to_mp4(
|
||||
frames=synced_frames,
|
||||
audio=audio_f32,
|
||||
sample_rate=sample_rate,
|
||||
fps=self.cfg.fps,
|
||||
)
|
||||
|
||||
def _pick_library_frames(
|
||||
self, audio_f32: np.ndarray, sample_rate: int
|
||||
) -> np.ndarray:
|
||||
"""Round-robin pick from the pre-baked library, clipped or looped
|
||||
to roughly the audio's duration so there's no long freeze frame."""
|
||||
if not self.speaking_base_frames:
|
||||
raise RuntimeError(
|
||||
"Library mode has no pre-baked base clips. "
|
||||
"Was set_avatar called with mode=library?"
|
||||
)
|
||||
frames = self.speaking_base_frames[
|
||||
self._library_cursor % len(self.speaking_base_frames)
|
||||
]
|
||||
self._library_cursor += 1
|
||||
|
||||
target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
|
||||
if target_frames <= 0:
|
||||
return frames
|
||||
if target_frames <= len(frames):
|
||||
return frames[:target_frames]
|
||||
# Loop (with a mirror tail to soften the seam) to cover longer audio.
|
||||
loops = target_frames // len(frames) + 1
|
||||
extended = np.concatenate([frames] * loops, axis=0)
|
||||
return extended[:target_frames]
|
||||
|
||||
def _derive_prompt(self, reply_text: str) -> str:
|
||||
"""Template-based prompt builder for reflective mode.
|
||||
|
||||
Takes up to ``prompt_reply_words`` words from the start of the reply
|
||||
and interpolates them into the configured template. Cheap,
|
||||
deterministic, no extra LLM call.
|
||||
"""
|
||||
words = (reply_text or "").split()
|
||||
hint = " ".join(words[: self.cfg.reflective_prompt_reply_words]).strip()
|
||||
if not hint:
|
||||
hint = "calm and friendly"
|
||||
return self.cfg.reflective_prompt_template.format(reply_hint=hint)
|
||||
Reference in New Issue
Block a user