Files
live-voice-chat/server/video.py
T
2026-04-12 04:11:52 -04:00

392 lines
15 KiB
Python

"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync.
Top-level orchestrator. The heavy 3rd-party model code is isolated in
``server/video_models/`` so each wrapper can be updated independently.
This module is only imported by ``server/models.py`` when
``config.video.enabled`` is true. When disabled, the existing voice pipeline
is completely untouched.
"""
from __future__ import annotations
import logging
import threading
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
# Allowed values for ``LoRASpec.target``: which Wan2.2 sub-model(s) a LoRA
# adapter is applied to.
LoRATarget = Literal["high_noise", "low_noise", "both"]
@dataclass
class LoRASpec:
    """A single LoRA adapter, as listed under ``config.video.loras``.

    Wan2.2 I2V is a Mixture-of-Experts model made of a high-noise and a
    low-noise sub-model. LightX2V distill LoRAs usually ship as a pair (one
    file per sub-model), and each file has to be applied to the matching
    sub-model via ``target``. ``target="both"`` is for adapters that belong
    on both sub-models, e.g. style LoRAs.
    """

    # Location of the adapter weights.
    path: str
    # Strength the adapter is applied with.
    weight: float = 1.0
    # Which sub-model(s) receive the adapter.
    target: LoRATarget = "both"
    # Optional human-readable label; log output falls back to ``path``.
    name: str | None = None
@dataclass
class VideoConfig:
    """Flattened view of the ``video:`` section of config.yml.

    The field defaults declared here are the single source of truth:
    ``from_dict`` falls back to them for every key that is missing or
    explicitly null in the raw config.
    """

    enabled: bool = False
    backend: str = "lightx2v"
    mode: str = "reflective"  # "library" | "reflective"
    resolution: int = 480
    fps: int = 16
    library_base_clip_count: int = 4
    library_base_clip_seconds: int = 6
    reflective_clip_seconds: int = 5
    reflective_prompt_template: str = (
        "webcam view of a person speaking, {reply_hint}, casual gestures, "
        "natural lighting, soft focus background"
    )
    reflective_prompt_reply_words: int = 18
    loras: list[LoRASpec] = field(default_factory=list)
    # Model paths — can be overridden via config.yml.video.models.
    #   wan22_base_repo  : HF repo id (or local dir) providing T5/VAE/tokenizer.
    #                      The bf16 DIT shards in this repo are skipped — we
    #                      replace them with the fp8 files from wan22_fp8_repo.
    #   wan22_fp8_repo   : HF repo id (or local dir) providing the two fp8 e4m3
    #                      4-step distilled DIT checkpoints (~15 GB each).
    #   wan22_config_json: path to the LightX2V inference config template the
    #                      Wan22Pipeline will fill in with absolute ckpt paths.
    wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
    wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models"
    wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
    wan22_model_cls: str = "wan2.2_moe_distill"
    musetalk_model_path: str = "TMElyralab/MuseTalk"

    @classmethod
    def from_dict(cls, raw: dict) -> "VideoConfig":
        """Build a ``VideoConfig`` from the parsed ``video:`` mapping.

        Args:
            raw: The ``video:`` section as a dict. ``None`` or an empty dict
                yields an all-defaults config. Nested sections read are
                ``library``, ``reflective``, ``models`` and the ``loras`` list.

        Returns:
            A populated ``VideoConfig``. Keys that are absent *or set to
            null* fall back to the dataclass field defaults instead of
            crashing in ``int(None)`` or turning into the string ``"None"``.

        LoRA entries that are not mappings or carry no usable ``path`` are
        skipped with a warning; an unknown ``target`` is coerced to
        ``"both"`` with a warning.
        """
        raw = raw or {}
        library = raw.get("library") or {}
        reflective = raw.get("reflective") or {}
        models_raw = raw.get("models") or {}

        def pick(section: dict, key: str, default):
            # A key present with a null value (``fps:`` left empty in YAML)
            # is treated the same as a missing key.
            value = section.get(key)
            return default if value is None else value

        loras: list[LoRASpec] = []
        for entry in raw.get("loras") or []:
            if not isinstance(entry, dict) or not entry.get("path"):
                log.warning("Skipping LoRA entry without a 'path': %r", entry)
                continue
            target = str(pick(entry, "target", "both")).lower()
            if target not in ("high_noise", "low_noise", "both"):
                log.warning(
                    "LoRA %s: invalid target %r, defaulting to 'both'",
                    entry.get("path"), target,
                )
                target = "both"
            loras.append(
                LoRASpec(
                    path=str(entry["path"]),
                    weight=float(pick(entry, "weight", 1.0)),
                    target=target,  # type: ignore[arg-type]
                    name=entry.get("name"),
                )
            )

        # ``cls.<field>`` is the dataclass default for that field, so the
        # defaults are not repeated (and cannot drift) here.
        return cls(
            enabled=bool(pick(raw, "enabled", cls.enabled)),
            backend=str(pick(raw, "backend", cls.backend)),
            mode=str(pick(raw, "mode", cls.mode)),
            resolution=int(pick(raw, "resolution", cls.resolution)),
            fps=int(pick(raw, "fps", cls.fps)),
            library_base_clip_count=int(
                pick(library, "base_clip_count", cls.library_base_clip_count)
            ),
            library_base_clip_seconds=int(
                pick(library, "base_clip_seconds", cls.library_base_clip_seconds)
            ),
            reflective_clip_seconds=int(
                pick(reflective, "clip_seconds", cls.reflective_clip_seconds)
            ),
            reflective_prompt_template=str(
                pick(
                    reflective,
                    "clip_prompt_template",
                    cls.reflective_prompt_template,
                )
            ),
            reflective_prompt_reply_words=int(
                pick(
                    reflective,
                    "prompt_reply_words",
                    cls.reflective_prompt_reply_words,
                )
            ),
            loras=loras,
            wan22_base_repo=str(
                pick(models_raw, "wan22_base_repo", cls.wan22_base_repo)
            ),
            wan22_fp8_repo=str(
                pick(models_raw, "wan22_fp8_repo", cls.wan22_fp8_repo)
            ),
            wan22_config_json=str(
                pick(models_raw, "wan22_config_json", cls.wan22_config_json)
            ),
            wan22_model_cls=str(
                pick(models_raw, "wan22_model_cls", cls.wan22_model_cls)
            ),
            # NOTE: the config key is ``musetalk_path``, not the field name.
            musetalk_model_path=str(
                pick(models_raw, "musetalk_path", cls.musetalk_model_path)
            ),
        )
# Prompts for the speaking base clips that library mode pre-renders. Every
# prompt shares the same framing and lighting and differs only in the gesture
# it describes, so the pre-baked set feels less repetitive when replayed.
# Kept at module level so tests can import them.
LIBRARY_BASE_PROMPTS = [
    (
        "webcam view of a person speaking, "
        "subtle head nods, casual expression, "
        "natural lighting, soft focus background"
    ),
    (
        "webcam view of a person speaking, "
        "slight smile, gentle hand gesture, "
        "natural lighting, soft focus background"
    ),
    (
        "webcam view of a person speaking, "
        "looking thoughtful, small head tilt, "
        "natural lighting, soft focus background"
    ),
    (
        "webcam view of a person speaking, "
        "engaged and attentive, minor shoulder movement, "
        "natural lighting, soft focus background"
    ),
    (
        "webcam view of a person speaking, "
        "relaxed posture, blinking naturally, "
        "natural lighting, soft focus background"
    ),
]

# Prompt for the idle loop: the avatar listens instead of speaking.
IDLE_PROMPT = (
    "webcam view of a person listening quietly, "
    "mouth closed, subtle breathing, occasional blinks, calm expression, "
    "natural lighting, soft focus background"
)
class VideoEngine:
    """Top-level video generation orchestrator.

    Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's
    pre-rendered clips. Exposed to ``ConversationSession`` via
    ``ModelManager.video_engine``.

    Typical lifecycle: construct, ``load_models()``, ``set_avatar(path)``,
    then ``generate_speaking_clip(...)`` once per assistant turn.
    """

    def __init__(self, cfg: VideoConfig):
        """Store the config and reset all avatar/model state.

        No model code is imported or run here; see ``load_models``.
        """
        self.cfg = cfg
        # Serialises LoRA reloads and avatar changes.
        self._lock = threading.Lock()
        # Avatar state
        self.avatar_path: str | None = None
        self.idle_clip_mp4: bytes | None = None
        # Pre-baked speaking base clips for library mode. Each entry is a
        # contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8.
        self.speaking_base_frames: list[np.ndarray] = []
        # Round-robin pointer for picking a library clip per turn
        self._library_cursor = 0
        # Model wrappers — instantiated lazily by ``load_models()`` so unit
        # tests can exercise VideoEngine without touching CUDA at all.
        self._wan22 = None  # server.video_models.wan22.Wan22Pipeline
        self._musetalk = None  # server.video_models.musetalk.MuseTalkEngine
        log.info(
            "VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).",
            cfg.mode, cfg.resolution, cfg.fps, len(cfg.loras),
        )

    # --- Model loading --------------------------------------------------
    def load_models(self) -> None:
        """Instantiate the underlying model wrappers.

        Separated from ``__init__`` so tests can mock ``_wan22``/``_musetalk``
        without triggering Wan2.2's ~12-16GB VRAM allocation. LoRAs listed
        in the config are applied right after the Wan2.2 pipeline is built.
        """
        # Imported lazily so merely importing this module stays cheap.
        from server.video_models.wan22 import Wan22Pipeline
        from server.video_models.musetalk import MuseTalkEngine

        log.info(
            "Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...",
            self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo,
        )
        self._wan22 = Wan22Pipeline(
            base_repo=self.cfg.wan22_base_repo,
            fp8_repo=self.cfg.wan22_fp8_repo,
            config_json=self.cfg.wan22_config_json,
            model_cls=self.cfg.wan22_model_cls,
            resolution=self.cfg.resolution,
            fps=self.cfg.fps,
        )
        if self.cfg.loras:
            self._wan22.load_loras(self.cfg.loras)
        log.info("Wan2.2 pipeline ready.")

        log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
        self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
        log.info("MuseTalk engine ready.")

    # --- Readiness ------------------------------------------------------
    def is_ready(self) -> bool:
        """True when an avatar is set and a speaking clip can be produced.

        Requires both model wrappers, an avatar and its idle clip. In
        library mode the pre-baked speaking clips are required as well:
        without them ``generate_speaking_clip`` would raise, and the idle
        clip becomes available before the library has finished baking.
        """
        if self._wan22 is None or self._musetalk is None:
            return False
        if self.avatar_path is None or self.idle_clip_mp4 is None:
            return False
        if self.cfg.mode == "library":
            return bool(self.speaking_base_frames)
        return True

    # --- LoRA management ------------------------------------------------
    def load_loras(self, specs: list[LoRASpec]) -> None:
        """Apply a list of LoRA adapters to the Wan2.2 base.

        Replaces any previously applied LoRAs. Safe to call after init for
        hot-reload via ``POST /api/reload-loras``.

        Raises:
            RuntimeError: if called before ``load_models()``.
        """
        if self._wan22 is None:
            raise RuntimeError("load_loras called before load_models()")
        with self._lock:
            self._wan22.unload_loras()
            self._wan22.load_loras(specs)
            # Store a copy so later mutation of the caller's list does not
            # silently change the recorded configuration.
            self.cfg.loras = list(specs)
            log.info("Applied %d LoRA(s): %s",
                     len(specs),
                     ", ".join(s.name or s.path for s in specs) or "<none>")

    # --- Avatar lifecycle ----------------------------------------------
    def set_avatar(self, image_path: str) -> None:
        """Register an avatar image and pre-generate cached clips.

        - Always: generate the idle loop.
        - Library mode: also pre-generate ``library.base_clip_count``
          speaking base clips.
        - Reflective mode: idle loop only.

        The idle clip is published as soon as it is rendered; the library
        clips are published together once all of them are rendered, so
        readers never observe a partially baked library.

        Raises:
            RuntimeError: if called before ``load_models()``.
        """
        if self._wan22 is None:
            raise RuntimeError("set_avatar called before load_models()")
        with self._lock:
            log.info("Setting avatar: %s", image_path)
            self.avatar_path = image_path
            # Drop any previously cached clips so the new avatar's library
            # doesn't mix with the old.
            self.speaking_base_frames = []
            self.idle_clip_mp4 = None

            # Idle clip: short loop, neutral/listening prompt.
            log.info("Generating idle clip...")
            idle_frames = self._wan22.generate_i2v(
                image_path=image_path,
                prompt=IDLE_PROMPT,
                seconds=self.cfg.library_base_clip_seconds,
                seed=0,
            )
            from server.video_models.muxer import frames_to_mp4_loop
            self.idle_clip_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps)
            log.info("Idle clip ready (%d bytes).", len(self.idle_clip_mp4))

            # Library mode: pre-bake N speaking base clips.
            if self.cfg.mode == "library":
                n = self.cfg.library_base_clip_count
                log.info("Pre-baking %d speaking base clip(s) for library mode.", n)
                baked: list[np.ndarray] = []
                for i in range(n):
                    # Cycle through the prompts when more clips are requested
                    # than there are prompts.
                    prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)]
                    frames = self._wan22.generate_i2v(
                        image_path=image_path,
                        prompt=prompt,
                        seconds=self.cfg.library_base_clip_seconds,
                        seed=i + 1,  # seed 0 is used by the idle clip
                    )
                    baked.append(frames)
                    log.info("  base clip %d/%d rendered", i + 1, n)
                # Publish the finished library in a single assignment.
                self.speaking_base_frames = baked
            self._library_cursor = 0

    def get_idle_clip(self) -> bytes | None:
        """Return the cached idle-loop MP4, or ``None`` if no avatar is set."""
        return self.idle_clip_mp4

    # --- Per-turn generation -------------------------------------------
    def generate_speaking_clip(
        self,
        audio_f32: np.ndarray,
        sample_rate: int,
        reply_text: str,
    ) -> bytes:
        """Produce a lip-synced MP4 for one assistant turn.

        Args:
            audio_f32: The turn's audio samples, passed through to the
                lip-sync engine and the muxer.
            sample_rate: Sample rate of ``audio_f32``.
            reply_text: Assistant reply; only used in reflective mode to
                derive the video prompt.

        Returns:
            The muxed MP4 as bytes.

        Raises:
            RuntimeError: if the engine is not ready (see ``is_ready``).
        """
        if not self.is_ready():
            raise RuntimeError(
                "generate_speaking_clip: engine not ready "
                "(avatar set? models loaded?)"
            )
        # Already guaranteed by is_ready(); these only narrow the types.
        assert self._wan22 is not None
        assert self._musetalk is not None

        # 1. Source base frames.
        if self.cfg.mode == "library":
            base_frames = self._pick_library_frames(audio_f32, sample_rate)
        else:  # reflective
            prompt = self._derive_prompt(reply_text)
            log.info("Reflective prompt: %s", prompt[:120])
            base_frames = self._wan22.generate_i2v(
                image_path=self.avatar_path or "",
                prompt=prompt,
                seconds=self.cfg.reflective_clip_seconds,
                seed=None,  # random each turn
            )

        # 2. Lip-sync the base frames to the given audio.
        synced_frames = self._musetalk.lip_sync(
            frames=base_frames,
            audio=audio_f32,
            sample_rate=sample_rate,
            fps=self.cfg.fps,
        )

        # 3. Mux frames + audio into an MP4.
        from server.video_models.muxer import frames_and_audio_to_mp4
        return frames_and_audio_to_mp4(
            frames=synced_frames,
            audio=audio_f32,
            sample_rate=sample_rate,
            fps=self.cfg.fps,
        )

    def _pick_library_frames(
        self, audio_f32: np.ndarray, sample_rate: int
    ) -> np.ndarray:
        """Round-robin pick from the pre-baked library, clipped or looped
        to roughly the audio's duration so there's no long freeze frame.

        Returns the whole clip when the audio is empty, a leading slice when
        the clip is long enough, and otherwise the clip repeated end-to-end
        (a plain loop, no mirroring) up to the required frame count.

        Raises:
            ValueError: if ``sample_rate`` is not positive.
            RuntimeError: if the library is empty, or the chosen clip has no
                frames and would have to be looped.
        """
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate!r}")
        if not self.speaking_base_frames:
            raise RuntimeError(
                "Library mode has no pre-baked base clips. "
                "Was set_avatar called with mode=library?"
            )
        clip = self.speaking_base_frames[
            self._library_cursor % len(self.speaking_base_frames)
        ]
        self._library_cursor += 1

        # Number of video frames needed to span the audio at the target fps.
        wanted = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
        available = len(clip)
        if wanted <= 0:
            return clip
        if wanted <= available:
            # NOTE(review): this is a view into the cached clip — assumes
            # downstream consumers do not modify frames in place.
            return clip[:wanted]
        if available == 0:
            raise RuntimeError(
                "Library base clip has no frames; cannot loop it to cover "
                "the audio."
            )
        # Loop the clip to cover longer audio. Indexing modulo the clip
        # length allocates exactly ``wanted`` frames instead of building
        # whole extra copies and slicing them down.
        return clip[np.arange(wanted) % available]

    def _derive_prompt(self, reply_text: str) -> str:
        """Template-based prompt builder for reflective mode.

        Takes up to ``prompt_reply_words`` words from the start of the reply
        and interpolates them into the configured template as
        ``{reply_hint}``. Cheap, deterministic, no extra LLM call. Falls
        back to a neutral hint when the reply yields no words.
        """
        # Clamp so a negative setting cannot turn the slice into "all but
        # the last N words".
        limit = max(self.cfg.reflective_prompt_reply_words, 0)
        words = (reply_text or "").split()
        hint = " ".join(words[:limit]) or "calm and friendly"
        return self.cfg.reflective_prompt_template.format(reply_hint=hint)