Files
bhetherman 44a10667c2 Enhance video handling and performance optimizations
- Added environment variables to prevent CPU thread pools from busy-waiting.
- Deferred loading of video models until first use to reduce VRAM footprint.
- Implemented streaming of speaking clips for improved responsiveness.
- Introduced a queue for managing speaking clips to handle multiple requests smoothly.
- Updated video playback logic to ensure proper handling of clip generation.
2026-04-24 00:36:18 -04:00

490 lines
19 KiB
Python

"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync.
Top-level orchestrator. The heavy 3rd-party model code is isolated in
``server/video_models/`` so each wrapper can be updated independently.
This module is only imported by ``server/models.py`` when
``config.video.enabled`` is true. When disabled, the existing voice pipeline
is completely untouched.
"""
from __future__ import annotations
import logging
import threading
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
# Module-level logger shared by VideoConfig parsing and VideoEngine.
log: logging.Logger = logging.getLogger(__name__)
# Type of ``LoRASpec.target``. The dense 5B DIT has a single weight set, so
# "both" is the only value this type admits.
LoRATarget = Literal["both"]
@dataclass
class LoRASpec:
    """A single LoRA adapter entry, as parsed from ``config.video.loras``.

    Attributes:
        path: Path (or identifier) of the adapter weights.
        weight: Scaling weight passed through to the pipeline; 1.0 by default.
        target: Which DIT weights the adapter attaches to. The dense
            Wan2.2-TI2V-5B DIT has no MoE experts, so this is always
            ``"both"``. The field survives for forward compatibility and so
            older MoE-era configs still parse: ``VideoConfig.from_dict``
            coerces legacy ``"high_noise"`` / ``"low_noise"`` to ``"both"``.
        name: Optional human-readable label, preferred over ``path`` in logs.
    """

    path: str
    weight: float = field(default=1.0)
    target: LoRATarget = field(default="both")
    name: str | None = field(default=None)
@dataclass
class VideoConfig:
    """Flattened view of the ``video:`` section of config.yml.

    The field defaults declared here are the single source of truth:
    ``from_dict`` falls back to them for every key the raw mapping omits.
    """

    enabled: bool = False
    backend: str = "lightx2v"
    mode: str = "reflective"  # "library" | "reflective"
    resolution: int = 480
    fps: int = 16
    library_base_clip_count: int = 4
    library_base_clip_seconds: int = 6
    reflective_clip_seconds: int = 5
    reflective_prompt_template: str = (
        "webcam view of a person speaking, {reply_hint}, casual gestures, "
        "natural lighting, soft focus background"
    )
    reflective_prompt_reply_words: int = 18
    loras: list[LoRASpec] = field(default_factory=list)
    # Model paths — can be overridden via config.yml.video.models.
    # wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
    #                   The bf16 DIT shards in this repo are skipped — we
    #                   replace them with a quantised GGUF from wan22_dit_repo.
    # wan22_dit_repo  : HF repo id (or local dir) providing the single
    #                   dense GGUF DIT checkpoint (5B Turbo).
    # wan22_dit_quant_scheme : GGUF quant level, e.g. "gguf-Q8_0" (default)
    #                   or "gguf-Q4_K_M" for lower VRAM.
    # wan22_config_json : path to the LightX2V inference config template the
    #                   Wan22Pipeline will fill in with absolute ckpt paths.
    wan22_base_repo: str = "Wan-AI/Wan2.2-TI2V-5B"
    wan22_dit_repo: str = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
    wan22_dit_quant_scheme: str = "gguf-Q8_0"
    wan22_t5_quantized: bool = True
    wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
    wan22_model_cls: str = "wan2.2"
    musetalk_enabled: bool = True
    musetalk_model_path: str = "TMElyralab/MuseTalk"

    @classmethod
    def from_dict(cls, raw: dict | None) -> "VideoConfig":
        """Build a config from the raw ``video:`` mapping of config.yml.

        Recognised layout (every key optional)::

            enabled, backend, mode, resolution, fps, musetalk_enabled
            library:    {base_clip_count, base_clip_seconds}
            reflective: {clip_seconds, clip_prompt_template, prompt_reply_words}
            models:     {wan22_base_repo, wan22_dit_repo, wan22_dit_quant_scheme,
                         wan22_t5_quantized, wan22_config_json, wan22_model_cls,
                         musetalk_path}
            musetalk:   {enabled}   # when a mapping, wins over musetalk_enabled
            loras:      [{path, weight, target, name} | "path", ...]

        Args:
            raw: The parsed ``video:`` section; ``None`` or empty yields an
                all-defaults config.

        Returns:
            A new ``VideoConfig``; missing keys take the field defaults.
        """
        raw = raw or {}
        # Fully-defaulted instance: the fallback for every missing key.
        d = cls()
        library = raw.get("library") or {}
        reflective = raw.get("reflective") or {}
        models = raw.get("models") or {}

        musetalk_section = raw.get("musetalk")
        if isinstance(musetalk_section, dict):
            musetalk_enabled = musetalk_section.get("enabled", d.musetalk_enabled)
        else:
            musetalk_enabled = raw.get("musetalk_enabled", d.musetalk_enabled)

        return cls(
            enabled=bool(raw.get("enabled", d.enabled)),
            backend=str(raw.get("backend", d.backend)),
            mode=str(raw.get("mode", d.mode)),
            resolution=int(raw.get("resolution", d.resolution)),
            fps=int(raw.get("fps", d.fps)),
            library_base_clip_count=int(
                library.get("base_clip_count", d.library_base_clip_count)
            ),
            library_base_clip_seconds=int(
                library.get("base_clip_seconds", d.library_base_clip_seconds)
            ),
            reflective_clip_seconds=int(
                reflective.get("clip_seconds", d.reflective_clip_seconds)
            ),
            reflective_prompt_template=str(
                reflective.get("clip_prompt_template", d.reflective_prompt_template)
            ),
            reflective_prompt_reply_words=int(
                reflective.get("prompt_reply_words", d.reflective_prompt_reply_words)
            ),
            loras=cls._parse_loras(raw.get("loras") or []),
            wan22_base_repo=str(models.get("wan22_base_repo", d.wan22_base_repo)),
            wan22_dit_repo=str(models.get("wan22_dit_repo", d.wan22_dit_repo)),
            wan22_dit_quant_scheme=str(
                models.get("wan22_dit_quant_scheme", d.wan22_dit_quant_scheme)
            ),
            wan22_t5_quantized=bool(
                models.get("wan22_t5_quantized", d.wan22_t5_quantized)
            ),
            wan22_config_json=str(
                models.get("wan22_config_json", d.wan22_config_json)
            ),
            wan22_model_cls=str(models.get("wan22_model_cls", d.wan22_model_cls)),
            musetalk_enabled=bool(musetalk_enabled),
            # NOTE: the config key is ``musetalk_path`` (not the field name).
            musetalk_model_path=str(
                models.get("musetalk_path", d.musetalk_model_path)
            ),
        )

    @staticmethod
    def _parse_loras(entries: list) -> list[LoRASpec]:
        """Turn raw ``loras:`` entries into ``LoRASpec`` objects.

        Each entry is either a mapping with a required ``path`` and optional
        ``weight`` / ``target`` / ``name``, or a bare path string (shorthand
        for weight 1.0, target "both"). Empty entries are skipped silently.
        Malformed ones (wrong type, or no ``path``) are skipped with a
        warning instead of crashing config loading. MoE-era targets are
        coerced to ``"both"``.
        """
        specs: list[LoRASpec] = []
        for entry in entries:
            if not entry:
                continue
            if isinstance(entry, str):
                entry = {"path": entry}
            if not isinstance(entry, dict) or "path" not in entry:
                log.warning(
                    "Skipping LoRA entry %r: expected a mapping with a "
                    "'path' key or a path string.",
                    entry,
                )
                continue
            target = str(entry.get("target", "both")).lower()
            if target != "both":
                log.warning(
                    "LoRA %s: target %r is MoE-era; coercing to 'both' "
                    "(dense 5B has a single DIT).",
                    entry.get("path"), target,
                )
                target = "both"
            specs.append(
                LoRASpec(
                    path=str(entry["path"]),
                    weight=float(entry.get("weight", 1.0)),
                    target=target,  # type: ignore[arg-type]
                    name=entry.get("name"),
                )
            )
        return specs
# Library-mode base-clip prompts: one gesture/expression phrase per clip,
# wrapped in the same webcam framing. The varied gestures make the pre-baked
# set feel less repetitive when replayed. Module-level so tests can import it.
LIBRARY_BASE_PROMPTS = [
    f"webcam view of a person speaking, {gesture}, "
    "natural lighting, soft focus background"
    for gesture in (
        "subtle head nods, casual expression",
        "slight smile, gentle hand gesture",
        "looking thoughtful, small head tilt",
        "engaged and attentive, minor shoulder movement",
        "relaxed posture, blinking naturally",
    )
]
# Prompt for the avatar's looping idle clip, rendered once per avatar in
# ``VideoEngine.set_avatar``.
IDLE_PROMPT = ", ".join(
    (
        "webcam view of a person listening quietly",
        "mouth closed",
        "subtle breathing",
        "occasional blinks",
        "calm expression",
        "natural lighting",
        "soft focus background",
    )
)
class VideoEngine:
    """Top-level video generation orchestrator.

    Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's
    pre-rendered clips. Exposed to ``ConversationSession`` via
    ``ModelManager.video_engine``.

    ``__init__`` does not touch model weights: ``set_avatar`` loads them on
    first use, or ``load_models`` can be called explicitly.
    """

    def __init__(self, cfg: VideoConfig):
        self.cfg = cfg
        # Serialises avatar (re)generation, lazy model loading and LoRA swaps.
        self._lock = threading.Lock()
        # Avatar state
        self.avatar_path: str | None = None
        self.idle_clip_mp4: bytes | None = None
        # Pre-baked speaking base clips for library mode. Each entry is a
        # contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8.
        self.speaking_base_frames: list[np.ndarray] = []
        # Round-robin pointer for picking a library clip per turn
        self._library_cursor = 0
        # Model wrappers — instantiated lazily by ``load_models()`` so unit
        # tests can exercise VideoEngine without touching CUDA at all.
        self._wan22 = None  # server.video_models.wan22.Wan22Pipeline
        self._musetalk = None  # server.video_models.musetalk.MuseTalkEngine
        log.info(
            "VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).",
            cfg.mode, cfg.resolution, cfg.fps, len(cfg.loras),
        )

    # --- Model loading --------------------------------------------------
    def load_models(self) -> None:
        """Instantiate the underlying model wrappers.

        Always (re)builds Wan2.2 and, when enabled, MuseTalk. Separated from
        ``__init__`` so tests can mock ``_wan22``/``_musetalk`` without
        triggering Wan2.2's ~12-16GB VRAM allocation.
        """
        self._load_wan22()
        self._load_musetalk()

    def _load_wan22(self) -> None:
        """Build the Wan2.2 pipeline and apply ``cfg.loras`` to it."""
        from server.video_models.wan22 import Wan22Pipeline

        log.info(
            "Loading Wan2.2 pipeline (base=%s, dit=%s, quant=%s)...",
            self.cfg.wan22_base_repo, self.cfg.wan22_dit_repo,
            self.cfg.wan22_dit_quant_scheme,
        )
        pipeline = Wan22Pipeline(
            base_repo=self.cfg.wan22_base_repo,
            dit_repo=self.cfg.wan22_dit_repo,
            config_json=self.cfg.wan22_config_json,
            model_cls=self.cfg.wan22_model_cls,
            resolution=self.cfg.resolution,
            fps=self.cfg.fps,
            dit_quant_scheme=self.cfg.wan22_dit_quant_scheme,
            t5_quantized=self.cfg.wan22_t5_quantized,
        )
        if self.cfg.loras:
            pipeline.load_loras(self.cfg.loras)
        # Publish only once fully configured, so a failed LoRA load leaves
        # ``_wan22`` unset and the next lazy load retries.
        self._wan22 = pipeline
        log.info("Wan2.2 pipeline ready.")

    def _load_musetalk(self) -> None:
        """Build the MuseTalk engine, or clear it when lip-sync is disabled."""
        if not self.cfg.musetalk_enabled:
            log.info("MuseTalk disabled via config — skipping lip-sync pass.")
            self._musetalk = None
            return
        from server.video_models.musetalk import MuseTalkEngine

        log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
        self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
        log.info("MuseTalk engine ready.")

    def _ensure_models_loaded(self) -> None:
        """Load whichever model wrappers are still missing.

        The caller must hold ``self._lock``, so concurrent first-use calls
        cannot both allocate the Wan2.2 weights.
        """
        if self._wan22 is None:
            self.load_models()
        elif self.cfg.musetalk_enabled and self._musetalk is None:
            # An earlier load brought Wan2.2 up but MuseTalk failed; retry
            # just MuseTalk instead of leaving is_ready() False forever.
            self._load_musetalk()

    # --- Readiness ------------------------------------------------------
    def is_ready(self) -> bool:
        """True when an avatar is set and a speaking clip can be produced."""
        musetalk_ok = (not self.cfg.musetalk_enabled) or self._musetalk is not None
        return (
            self._wan22 is not None
            and musetalk_ok
            and self.avatar_path is not None
            and self.idle_clip_mp4 is not None
        )

    def _require_ready(self, caller: str) -> None:
        """Raise ``RuntimeError`` naming ``caller`` unless ``is_ready()``."""
        if not self.is_ready():
            raise RuntimeError(
                f"{caller}: engine not ready (avatar set? models loaded?)"
            )

    @staticmethod
    def _check_sample_rate(sample_rate: int) -> None:
        """Reject non-positive sample rates before they reach a division."""
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")

    # --- LoRA management ------------------------------------------------
    def load_loras(self, specs: list[LoRASpec]) -> None:
        """Apply a list of LoRA adapters to the Wan2.2 base.

        Replaces any previously applied LoRAs. Safe to call after init for
        hot-reload via ``POST /api/reload-loras``. If the models have not been
        loaded yet (they load lazily on the first ``set_avatar``), the specs
        are only recorded in ``cfg.loras`` and applied when loading happens.
        """
        specs = list(specs)
        with self._lock:
            if self._wan22 is None:
                self.cfg.loras = specs
                log.info(
                    "Models not loaded yet; %d LoRA(s) recorded and will be "
                    "applied on first load.",
                    len(specs),
                )
                return
            self._wan22.unload_loras()
            self._wan22.load_loras(specs)
            self.cfg.loras = specs
        log.info("Applied %d LoRA(s): %s",
                 len(specs),
                 ", ".join(s.name or s.path for s in specs) or "<none>")

    # --- Avatar lifecycle ----------------------------------------------
    def set_avatar(self, image_path: str) -> None:
        """Register an avatar image and pre-generate cached clips.

        - Always: generate the idle loop.
        - Library mode: also pre-generate ``library.base_clip_count``
          speaking base clips.
        - Reflective mode: idle loop only.

        Loads any missing models first, so Wan2.2's VRAM footprint doesn't
        exist until video is actually used. Loading happens under
        ``self._lock``, so concurrent first calls load the weights only once.
        """
        # Import before touching state so a missing muxer fails fast.
        from server.video_models.muxer import frames_to_mp4_loop

        with self._lock:
            self._ensure_models_loaded()
            log.info("Setting avatar: %s", image_path)
            self.avatar_path = image_path
            # Drop any previously cached clips so the new avatar's library
            # doesn't mix with the old.
            self.speaking_base_frames = []
            self.idle_clip_mp4 = None
            # Idle clip: short loop, neutral/listening prompt. Its length
            # reuses library_base_clip_seconds in both modes.
            log.info("Generating idle clip...")
            idle_frames = self._wan22.generate_i2v(
                image_path=image_path,
                prompt=IDLE_PROMPT,
                seconds=self.cfg.library_base_clip_seconds,
                seed=0,
            )
            self.idle_clip_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps)
            log.info("Idle clip ready (%d bytes).", len(self.idle_clip_mp4))
            # Library mode: pre-bake N speaking base clips.
            if self.cfg.mode == "library":
                n = self.cfg.library_base_clip_count
                log.info("Pre-baking %d speaking base clip(s) for library mode.", n)
                for i in range(n):
                    prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)]
                    frames = self._wan22.generate_i2v(
                        image_path=image_path,
                        prompt=prompt,
                        seconds=self.cfg.library_base_clip_seconds,
                        seed=i + 1,
                    )
                    self.speaking_base_frames.append(frames)
                    log.info("  base clip %d/%d rendered", i + 1, n)
            self._library_cursor = 0

    def get_idle_clip(self) -> bytes | None:
        """Return the cached idle-loop MP4, or ``None`` if no avatar is set."""
        return self.idle_clip_mp4

    # --- Per-turn generation -------------------------------------------
    def generate_speaking_clip(
        self,
        audio_f32: np.ndarray,
        sample_rate: int,
        reply_text: str,
    ) -> bytes:
        """Produce a lip-synced MP4 for one assistant turn.

        Raises:
            RuntimeError: if the engine is not ready (no avatar / models).
            ValueError: if ``sample_rate`` is not positive.
        """
        self._require_ready("generate_speaking_clip")
        self._check_sample_rate(sample_rate)
        # 1. Source base frames.
        if self.cfg.mode == "library":
            base_frames = self._pick_library_frames(audio_f32, sample_rate)
        else:  # reflective
            prompt = self._derive_prompt(reply_text)
            log.info("Reflective prompt: %s", prompt[:120])
            base_frames = self._reflective_frames(prompt, len(audio_f32), sample_rate)
        # 2 + 3. Lip-sync (if enabled) and mux with the audio.
        return self._render_clip(base_frames, audio_f32, sample_rate)

    def _trim_to_audio(
        self, frames: np.ndarray, n_samples: int, sample_rate: int
    ) -> np.ndarray:
        """Clip ``frames`` to the video length covering ``n_samples`` of audio.

        Empty audio leaves ``frames`` untouched. Non-empty audio always keeps
        at least one frame, so a tiny tail segment no longer gets the whole
        clip. Never pads: a clip shorter than the audio is returned as-is.
        """
        if n_samples <= 0:
            return frames
        target = max(1, round(n_samples / sample_rate * self.cfg.fps))
        return frames[:target]

    def _pick_library_frames(
        self, audio_f32: np.ndarray, sample_rate: int
    ) -> np.ndarray:
        """Round-robin pick from the pre-baked library, clipped to the segment duration.

        Does not loop frames — callers that need longer coverage should split
        the audio into segments and call this once per segment.
        """
        if not self.speaking_base_frames:
            raise RuntimeError(
                "Library mode has no pre-baked base clips. "
                "Was set_avatar called with mode=library?"
            )
        frames = self.speaking_base_frames[
            self._library_cursor % len(self.speaking_base_frames)
        ]
        self._library_cursor += 1
        return self._trim_to_audio(frames, len(audio_f32), sample_rate)

    def _reflective_frames(
        self, prompt: str, n_samples: int, sample_rate: int
    ) -> np.ndarray:
        """Render a fresh Wan2.2 clip for ``prompt``, trimmed to the audio length."""
        frames = self._wan22.generate_i2v(
            image_path=self.avatar_path or "",
            prompt=prompt,
            seconds=self.cfg.reflective_clip_seconds,
            seed=None,  # random each turn
        )
        return self._trim_to_audio(frames, n_samples, sample_rate)

    def _render_clip(
        self, base_frames: np.ndarray, audio_f32: np.ndarray, sample_rate: int
    ) -> bytes:
        """Lip-sync ``base_frames`` to ``audio_f32`` (if MuseTalk is on) and mux to MP4."""
        from server.video_models.muxer import frames_and_audio_to_mp4

        if self._musetalk is not None:
            frames = self._musetalk.lip_sync(
                frames=base_frames,
                audio=audio_f32,
                sample_rate=sample_rate,
                fps=self.cfg.fps,
            )
        else:
            frames = base_frames
        return frames_and_audio_to_mp4(
            frames=frames,
            audio=audio_f32,
            sample_rate=sample_rate,
            fps=self.cfg.fps,
        )

    def generate_speaking_clips_streaming(
        self,
        audio_f32: np.ndarray,
        sample_rate: int,
        reply_text: str,
    ) -> Iterator[tuple[bytes, int]]:
        """Generate one MP4 per clip-length audio segment, yielding each when ready.

        Splits ``audio_f32`` into segments of ``reflective_clip_seconds`` (or
        ``library_base_clip_seconds`` for library mode) and generates +
        lip-syncs one clip per segment. Yields ``(mp4_bytes, duration_ms)``
        tuples so the caller can stream each clip to the client as soon as
        it's ready rather than waiting for the full response. Each clip's
        video is trimmed so it is never longer than its segment's audio
        (relevant for the usually shorter final segment).

        This is a generator: the checks below run on the first ``next()``.

        Raises:
            RuntimeError: if the engine is not ready.
            ValueError: if ``sample_rate`` or the configured clip length is
                not positive.
        """
        self._require_ready("generate_speaking_clips_streaming")
        if len(audio_f32) == 0:
            return
        self._check_sample_rate(sample_rate)
        library_mode = self.cfg.mode == "library"
        clip_sec = (
            self.cfg.library_base_clip_seconds
            if library_mode
            else self.cfg.reflective_clip_seconds
        )
        clip_samples = int(clip_sec * sample_rate)
        if clip_samples <= 0:
            raise ValueError(
                f"clip length must be positive (got {clip_sec}s at {sample_rate} Hz)"
            )
        # The prompt depends only on reply_text, so derive it once per turn.
        prompt = "" if library_mode else self._derive_prompt(reply_text)

        for start in range(0, len(audio_f32), clip_samples):
            seg_audio = audio_f32[start : start + clip_samples]
            if library_mode:
                base_frames = self._pick_library_frames(seg_audio, sample_rate)
            else:
                log.info("Reflective prompt (clip segment): %s", prompt[:80])
                base_frames = self._reflective_frames(
                    prompt, len(seg_audio), sample_rate
                )
            mp4_bytes = self._render_clip(base_frames, seg_audio, sample_rate)
            duration_ms = int(len(seg_audio) / sample_rate * 1000)
            yield mp4_bytes, duration_ms

    def _derive_prompt(self, reply_text: str) -> str:
        """Template-based prompt builder for reflective mode.

        Takes up to ``prompt_reply_words`` words from the start of the reply
        and interpolates them into the configured template. Cheap,
        deterministic, no extra LLM call.
        """
        words = (reply_text or "").split()
        hint = " ".join(words[: self.cfg.reflective_prompt_reply_words]).strip()
        if not hint:
            hint = "calm and friendly"
        return self.cfg.reflective_prompt_template.format(reply_hint=hint)