first stab at adding video

This commit is contained in:
2026-04-12 04:11:52 -04:00
parent 680c5b04cc
commit 2818b41004
37 changed files with 2982 additions and 24 deletions
+121 -2
View File
@@ -1,23 +1,27 @@
import json
import logging
import os
import tempfile
from contextlib import asynccontextmanager
import numpy as np
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.params import Form
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
from server.audio_utils import pcm_bytes_to_float32
from server.models import ModelManager
from server.pipeline import ConversationSession
from server.video import LoRASpec
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)
REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio")
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
AVATAR_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "avatars")
os.makedirs(AVATAR_DIR, exist_ok=True)
model_mgr = ModelManager()
@@ -47,6 +51,110 @@ async def set_voice(voice: str = Form(...), lang: str = Form("a")):
return {"status": "ok", "voice": voice}
# --- Video / avatar endpoints ---------------------------------------------
def _require_video() -> "object":
"""Return the video engine, or raise 404 if video mode isn't enabled."""
ve = model_mgr.video_engine
if ve is None:
raise HTTPException(
status_code=404,
detail="Video engine disabled. Set config.video.enabled=true and restart.",
)
return ve
@app.post("/api/set-avatar")
async def set_avatar(image: UploadFile):
"""Upload an avatar image and (re)generate cached clips."""
ve = _require_video()
suffix = os.path.splitext(image.filename or "avatar.png")[1] or ".png"
dest = os.path.join(AVATAR_DIR, f"avatar{suffix}")
with open(dest, "wb") as f:
f.write(await image.read())
log.info("Avatar saved to %s", dest)
import asyncio
try:
await asyncio.to_thread(ve.set_avatar, dest)
except Exception as e:
log.exception("set_avatar failed")
raise HTTPException(status_code=500, detail=f"Avatar setup failed: {e}")
return {
"status": "ok",
"avatar_path": dest,
"idle_clip_url": "/api/idle-clip",
"mode": ve.cfg.mode,
}
@app.get("/api/idle-clip")
async def idle_clip():
"""Return the cached idle loop MP4."""
ve = _require_video()
data = ve.get_idle_clip()
if data is None:
raise HTTPException(status_code=404, detail="No idle clip. Upload an avatar first.")
return Response(content=data, media_type="video/mp4")
@app.post("/api/set-video-mode")
async def set_video_mode(mode: str = Form(...)):
"""Switch between 'off', 'library', and 'reflective'.
'off' leaves the video engine loaded but makes the pipeline take the
PCM streaming path on subsequent turns (by marking the engine not-ready
from the client's perspective via a simple flag).
"""
ve = _require_video()
if mode not in ("off", "library", "reflective"):
raise HTTPException(
status_code=400,
detail="mode must be one of: off, library, reflective",
)
# Switching between library/reflective changes how set_avatar prebakes
# clips. Require a fresh avatar upload afterwards to re-bake.
ve.cfg.mode = mode
return {"status": "ok", "mode": mode, "note": "Re-upload avatar to re-bake library clips." if mode == "library" else ""}
@app.post("/api/reload-loras")
async def reload_loras(body: dict):
"""Hot-reload LoRA stack. Body: ``{"loras": [{"path","weight","target","name"}]}``.
Regenerates the idle clip if an avatar is already set, since the new
LoRAs change the base style.
"""
ve = _require_video()
raw = body.get("loras") or []
specs: list[LoRASpec] = []
for entry in raw:
if not entry or "path" not in entry:
continue
target = str(entry.get("target", "both")).lower()
if target not in ("high_noise", "low_noise", "both"):
target = "both"
specs.append(
LoRASpec(
path=str(entry["path"]),
weight=float(entry.get("weight", 1.0)),
target=target, # type: ignore[arg-type]
name=entry.get("name"),
)
)
import asyncio
try:
await asyncio.to_thread(ve.load_loras, specs)
if ve.avatar_path:
log.info("Regenerating idle clip after LoRA reload.")
await asyncio.to_thread(ve.set_avatar, ve.avatar_path)
except Exception as e:
log.exception("reload_loras failed")
raise HTTPException(status_code=500, detail=str(e))
return {"status": "ok", "lora_count": len(specs), "idle_clip_url": "/api/idle-clip"}
@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
await ws.accept()
@@ -61,6 +169,17 @@ async def websocket_chat(ws: WebSocket):
session = ConversationSession(model_mgr, send_json, send_bytes)
await session.start()
# Tell the client whether video mode is active so it knows whether to
# suppress PCM playback and wait for speaking_clip messages instead.
ve = model_mgr.video_engine
await send_json({
"type": "video_mode",
"enabled": ve is not None,
"ready": ve.is_ready() if ve is not None else False,
"mode": ve.cfg.mode if ve is not None else "off",
"idle_clip_url": "/api/idle-clip" if (ve is not None and ve.get_idle_clip()) else None,
})
try:
while True:
message = await ws.receive()
+26 -2
View File
@@ -5,6 +5,7 @@ from server.vad import StreamingVAD
from server.asr import ASREngine
from server.llm import LLMEngine
from server.tts import TTSEngine
from server.video import VideoConfig, VideoEngine
log = logging.getLogger(__name__)
@@ -31,6 +32,7 @@ class ModelManager:
self.asr_engine: ASREngine | None = None
self.llm_engine: LLMEngine | None = None
self.tts_engine: TTSEngine | None = None
self.video_engine: VideoEngine | None = None
def load_all(self):
"""Load all models sequentially. Call from the main process."""
@@ -38,6 +40,7 @@ class ModelManager:
self._load_asr()
self._load_llm()
self._load_tts()
self._load_video()
log.info("All models loaded successfully.")
def _load_vad(self):
@@ -84,8 +87,8 @@ class ModelManager:
log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "Qwen/Qwen3.5-0.8B"
# model_name = "Qwen/Qwen3.5-0.8B"
model_name = "dphn/Dolphin-X1-8B-FP8"
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = get_device()
model = AutoModelForCausalLM.from_pretrained(
@@ -101,6 +104,27 @@ class ModelManager:
self.tts_engine = TTSEngine()
log.info("Kokoro TTS loaded.")
def _load_video(self):
"""Load the avatar video stack iff config.video.enabled is true.
Leaves ``video_engine`` as None when disabled so existing voice flow
is untouched. Later phases replace this stub with actual Wan2.2 +
MuseTalk loading inside ``VideoEngine``.
"""
from server.config import config
video_cfg_raw = config.get("video", {}) or {}
if not video_cfg_raw.get("enabled", False):
log.info("Video engine disabled (config.video.enabled=false). Skipping load.")
return
log.info("Loading avatar video engine...")
cfg = VideoConfig.from_dict(video_cfg_raw)
self.video_engine = VideoEngine(cfg)
if cfg.loras:
self.video_engine.load_loras(cfg.loras)
log.info("Avatar video engine loaded (mode=%s).", cfg.mode)
def create_vad(self) -> StreamingVAD:
"""Create a new StreamingVAD instance for a client session."""
return StreamingVAD(self.vad_model)
+73 -14
View File
@@ -157,11 +157,20 @@ class ConversationSession:
# TTS - stream chunks with per-sentence text
await self.send_json({"type": "status", "state": "speaking"})
# Video-mode branch: if a video engine is loaded AND an avatar is
# set, buffer the full TTS output into a single blob, run MuseTalk
# lip-sync (library or reflective source), mux to MP4, and send the
# full clip + text in one shot. The client plays the MP4 (which
# carries audio) instead of the per-chunk PCM path.
video_engine = getattr(self.models, "video_engine", None)
use_video = video_engine is not None and video_engine.is_ready()
chunk_queue = queue.Queue()
self._last_played_chunk_id = None
segments = _split_into_segments(response)
log.info(f"TTS: split response into {len(segments)} segments")
log.info(f"TTS: split response into {len(segments)} segments (video={use_video})")
def _tts_worker():
try:
@@ -187,6 +196,10 @@ class ConversationSession:
chunk_id = 0
# Maps chunk_id -> cumulative text up to and including that chunk
chunk_text_map: dict[int, str] = {}
# Video mode accumulator: we buffer all TTS audio into one float32
# array so MuseTalk can align against the full utterance.
audio_buffer: list[np.ndarray] = []
while True:
try:
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
@@ -202,23 +215,69 @@ class ConversationSession:
spoken_text += sentence_text
chunk_text_map[chunk_id] = spoken_text
await self.send_json({
"type": "response_text",
"text": sentence_text,
"chunk_id": chunk_id,
"final": False,
})
pcm_bytes = float32_to_pcm_bytes(audio)
try:
await self.send_bytes(pcm_bytes)
except Exception:
log.warning("Failed to send audio, client disconnected.")
self.cancel_event.set()
break
if use_video:
audio_buffer.append(audio)
# Don't stream text or PCM during video mode — we'll send
# everything after the clip renders so the client doesn't
# start displaying text before the video is ready.
else:
await self.send_json({
"type": "response_text",
"text": sentence_text,
"chunk_id": chunk_id,
"final": False,
})
pcm_bytes = float32_to_pcm_bytes(audio)
try:
await self.send_bytes(pcm_bytes)
except Exception:
log.warning("Failed to send audio, client disconnected.")
self.cancel_event.set()
break
chunk_id += 1
tts_thread.join(timeout=2.0)
# Video mode: render the speaking clip now that TTS is done.
if use_video and audio_buffer and not self.cancel_event.is_set():
try:
full_audio = np.concatenate(audio_buffer).astype(np.float32)
sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000)
log.info(
"Video: rendering speaking clip (audio=%ds, mode=%s)",
int(len(full_audio) / sample_rate), video_engine.cfg.mode,
)
mp4_bytes = await asyncio.to_thread(
video_engine.generate_speaking_clip,
full_audio,
sample_rate,
response,
)
if self.cancel_event.is_set():
log.info("Video clip discarded (cancelled during render).")
else:
duration_ms = int(len(full_audio) / sample_rate * 1000)
await self.send_json({
"type": "speaking_clip",
"chunk_id": 0,
"duration_ms": duration_ms,
"text": response,
"size_bytes": len(mp4_bytes),
})
await self.send_bytes(mp4_bytes)
except Exception:
log.exception("Video speaking-clip render failed; falling back silently.")
# Best-effort: tell the client nothing was spoken visually.
try:
await self.send_json({
"type": "response_text",
"text": response,
"chunk_id": 0,
"final": True,
})
except Exception:
pass
# Determine what was actually heard by the client
was_interrupted = spoken_text.strip() != response.strip()
if was_interrupted and self._last_played_chunk_id is not None:
+391
View File
@@ -0,0 +1,391 @@
"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync.
Top-level orchestrator. The heavy 3rd-party model code is isolated in
``server/video_models/`` so each wrapper can be updated independently.
This module is only imported by ``server/models.py`` when
``config.video.enabled`` is true. When disabled, the existing voice pipeline
is completely untouched.
"""
from __future__ import annotations
import logging
import threading
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
log = logging.getLogger(__name__)
LoRATarget = Literal["high_noise", "low_noise", "both"]
@dataclass
class LoRASpec:
"""One LoRA adapter entry from ``config.video.loras``.
Wan2.2 I2V is a Mixture-of-Experts model with separate high-noise and
low-noise sub-models. Most LightX2V distill LoRAs come paired (one per
sub-model) and must be applied to the correct target. Allow
``target="both"`` for LoRAs that should be applied to both sub-models
(e.g. style LoRAs).
"""
path: str
weight: float = 1.0
target: LoRATarget = "both"
name: str | None = None
@dataclass
class VideoConfig:
"""Flattened view of the ``video:`` section of config.yml."""
enabled: bool = False
backend: str = "lightx2v"
mode: str = "reflective" # "library" | "reflective"
resolution: int = 480
fps: int = 16
library_base_clip_count: int = 4
library_base_clip_seconds: int = 6
reflective_clip_seconds: int = 5
reflective_prompt_template: str = (
"webcam view of a person speaking, {reply_hint}, casual gestures, "
"natural lighting, soft focus background"
)
reflective_prompt_reply_words: int = 18
loras: list[LoRASpec] = field(default_factory=list)
# Model paths — can be overridden via config.yml.video.models.
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
# The bf16 DIT shards in this repo are skipped — we
# replace them with the fp8 files from wan22_fp8_repo.
# wan22_fp8_repo : HF repo id (or local dir) providing the two fp8 e4m3
# 4-step distilled DIT checkpoints (~15 GB each).
# wan22_config_json: path to the LightX2V inference config template the
# Wan22Pipeline will fill in with absolute ckpt paths.
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models"
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
wan22_model_cls: str = "wan2.2_moe_distill"
musetalk_model_path: str = "TMElyralab/MuseTalk"
@classmethod
def from_dict(cls, raw: dict) -> "VideoConfig":
raw = raw or {}
library = raw.get("library", {}) or {}
reflective = raw.get("reflective", {}) or {}
models_raw = raw.get("models", {}) or {}
loras_raw = raw.get("loras") or []
default_template = (
"webcam view of a person speaking, {reply_hint}, casual gestures, "
"natural lighting, soft focus background"
)
loras: list[LoRASpec] = []
for entry in loras_raw:
if not entry or "path" not in entry:
continue
target = str(entry.get("target", "both")).lower()
if target not in ("high_noise", "low_noise", "both"):
log.warning(
"LoRA %s: invalid target %r, defaulting to 'both'",
entry.get("path"), target,
)
target = "both"
loras.append(
LoRASpec(
path=str(entry["path"]),
weight=float(entry.get("weight", 1.0)),
target=target, # type: ignore[arg-type]
name=entry.get("name"),
)
)
return cls(
enabled=bool(raw.get("enabled", False)),
backend=str(raw.get("backend", "lightx2v")),
mode=str(raw.get("mode", "reflective")),
resolution=int(raw.get("resolution", 480)),
fps=int(raw.get("fps", 16)),
library_base_clip_count=int(library.get("base_clip_count", 4)),
library_base_clip_seconds=int(library.get("base_clip_seconds", 6)),
reflective_clip_seconds=int(reflective.get("clip_seconds", 5)),
reflective_prompt_template=str(
reflective.get("clip_prompt_template", default_template)
),
reflective_prompt_reply_words=int(reflective.get("prompt_reply_words", 18)),
loras=loras,
wan22_base_repo=str(
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
),
wan22_fp8_repo=str(
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models")
),
wan22_config_json=str(
models_raw.get(
"wan22_config_json",
"/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
)
),
wan22_model_cls=str(
models_raw.get("wan22_model_cls", "wan2.2_moe_distill")
),
musetalk_model_path=str(
models_raw.get("musetalk_path", "TMElyralab/MuseTalk")
),
)
# Library-mode base-clip prompts. Varied gestures so the pre-baked set feels
# less repetitive when replayed. Kept module-level so tests can import them.
LIBRARY_BASE_PROMPTS = [
"webcam view of a person speaking, subtle head nods, casual expression, "
"natural lighting, soft focus background",
"webcam view of a person speaking, slight smile, gentle hand gesture, "
"natural lighting, soft focus background",
"webcam view of a person speaking, looking thoughtful, small head tilt, "
"natural lighting, soft focus background",
"webcam view of a person speaking, engaged and attentive, minor shoulder "
"movement, natural lighting, soft focus background",
"webcam view of a person speaking, relaxed posture, blinking naturally, "
"natural lighting, soft focus background",
]
IDLE_PROMPT = (
"webcam view of a person listening quietly, mouth closed, subtle "
"breathing, occasional blinks, calm expression, natural lighting, "
"soft focus background"
)
class VideoEngine:
"""Top-level video generation orchestrator.
Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's
pre-rendered clips. Exposed to ``ConversationSession`` via
``ModelManager.video_engine``.
"""
def __init__(self, cfg: VideoConfig):
self.cfg = cfg
self._lock = threading.Lock()
# Avatar state
self.avatar_path: str | None = None
self.idle_clip_mp4: bytes | None = None
# Pre-baked speaking base clips for library mode. Each entry is a
# contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8.
self.speaking_base_frames: list[np.ndarray] = []
# Round-robin pointer for picking a library clip per turn
self._library_cursor = 0
# Model wrappers — instantiated lazily by ``load_models()`` so unit
# tests can exercise VideoEngine without touching CUDA at all.
self._wan22 = None # server.video_models.wan22.Wan22Pipeline
self._musetalk = None # server.video_models.musetalk.MuseTalkEngine
log.info(
"VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).",
cfg.mode, cfg.resolution, cfg.fps, len(cfg.loras),
)
# --- Model loading --------------------------------------------------
def load_models(self) -> None:
"""Instantiate the underlying model wrappers.
Separated from ``__init__`` so tests can mock ``_wan22``/``_musetalk``
without triggering Wan2.2's ~12-16GB VRAM allocation.
"""
from server.video_models.wan22 import Wan22Pipeline
from server.video_models.musetalk import MuseTalkEngine
log.info(
"Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...",
self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo,
)
self._wan22 = Wan22Pipeline(
base_repo=self.cfg.wan22_base_repo,
fp8_repo=self.cfg.wan22_fp8_repo,
config_json=self.cfg.wan22_config_json,
model_cls=self.cfg.wan22_model_cls,
resolution=self.cfg.resolution,
fps=self.cfg.fps,
)
if self.cfg.loras:
self._wan22.load_loras(self.cfg.loras)
log.info("Wan2.2 pipeline ready.")
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
log.info("MuseTalk engine ready.")
# --- Readiness ------------------------------------------------------
def is_ready(self) -> bool:
"""True when an avatar is set and a speaking clip can be produced."""
return (
self._wan22 is not None
and self._musetalk is not None
and self.avatar_path is not None
and self.idle_clip_mp4 is not None
)
# --- LoRA management ------------------------------------------------
def load_loras(self, specs: list[LoRASpec]) -> None:
"""Apply a list of LoRA adapters to the Wan2.2 base.
Replaces any previously applied LoRAs. Safe to call after init for
hot-reload via ``POST /api/reload-loras``.
"""
if self._wan22 is None:
raise RuntimeError("load_loras called before load_models()")
with self._lock:
self._wan22.unload_loras()
self._wan22.load_loras(specs)
self.cfg.loras = list(specs)
log.info("Applied %d LoRA(s): %s",
len(specs),
", ".join(s.name or s.path for s in specs) or "<none>")
# --- Avatar lifecycle ----------------------------------------------
def set_avatar(self, image_path: str) -> None:
"""Register an avatar image and pre-generate cached clips.
- Always: generate the idle loop.
- Library mode: also pre-generate ``library.base_clip_count``
speaking base clips.
- Reflective mode: idle loop only.
"""
if self._wan22 is None:
raise RuntimeError("set_avatar called before load_models()")
with self._lock:
log.info("Setting avatar: %s", image_path)
self.avatar_path = image_path
# Drop any previously cached clips so the new avatar's library
# doesn't mix with the old.
self.speaking_base_frames = []
self.idle_clip_mp4 = None
# Idle clip: short loop, neutral/listening prompt.
log.info("Generating idle clip...")
idle_frames = self._wan22.generate_i2v(
image_path=image_path,
prompt=IDLE_PROMPT,
seconds=self.cfg.library_base_clip_seconds,
seed=0,
)
from server.video_models.muxer import frames_to_mp4_loop
self.idle_clip_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps)
log.info("Idle clip ready (%d bytes).", len(self.idle_clip_mp4))
# Library mode: pre-bake N speaking base clips.
if self.cfg.mode == "library":
n = self.cfg.library_base_clip_count
log.info("Pre-baking %d speaking base clip(s) for library mode.", n)
for i in range(n):
prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)]
frames = self._wan22.generate_i2v(
image_path=image_path,
prompt=prompt,
seconds=self.cfg.library_base_clip_seconds,
seed=i + 1,
)
self.speaking_base_frames.append(frames)
log.info(" base clip %d/%d rendered", i + 1, n)
self._library_cursor = 0
def get_idle_clip(self) -> bytes | None:
return self.idle_clip_mp4
# --- Per-turn generation -------------------------------------------
def generate_speaking_clip(
self,
audio_f32: np.ndarray,
sample_rate: int,
reply_text: str,
) -> bytes:
"""Produce a lip-synced MP4 for one assistant turn."""
if not self.is_ready():
raise RuntimeError(
"generate_speaking_clip: engine not ready "
"(avatar set? models loaded?)"
)
assert self._wan22 is not None
assert self._musetalk is not None
# 1. Source base frames.
if self.cfg.mode == "library":
base_frames = self._pick_library_frames(audio_f32, sample_rate)
else: # reflective
prompt = self._derive_prompt(reply_text)
log.info("Reflective prompt: %s", prompt[:120])
base_frames = self._wan22.generate_i2v(
image_path=self.avatar_path or "",
prompt=prompt,
seconds=self.cfg.reflective_clip_seconds,
seed=None, # random each turn
)
# 2. Lip-sync the base frames to the given audio.
synced_frames = self._musetalk.lip_sync(
frames=base_frames,
audio=audio_f32,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
# 3. Mux frames + audio into an MP4.
from server.video_models.muxer import frames_and_audio_to_mp4
return frames_and_audio_to_mp4(
frames=synced_frames,
audio=audio_f32,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
def _pick_library_frames(
self, audio_f32: np.ndarray, sample_rate: int
) -> np.ndarray:
"""Round-robin pick from the pre-baked library, clipped or looped
to roughly the audio's duration so there's no long freeze frame."""
if not self.speaking_base_frames:
raise RuntimeError(
"Library mode has no pre-baked base clips. "
"Was set_avatar called with mode=library?"
)
frames = self.speaking_base_frames[
self._library_cursor % len(self.speaking_base_frames)
]
self._library_cursor += 1
target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
if target_frames <= 0:
return frames
if target_frames <= len(frames):
return frames[:target_frames]
# Loop (with a mirror tail to soften the seam) to cover longer audio.
loops = target_frames // len(frames) + 1
extended = np.concatenate([frames] * loops, axis=0)
return extended[:target_frames]
def _derive_prompt(self, reply_text: str) -> str:
"""Template-based prompt builder for reflective mode.
Takes up to ``prompt_reply_words`` words from the start of the reply
and interpolates them into the configured template. Cheap,
deterministic, no extra LLM call.
"""
words = (reply_text or "").split()
hint = " ".join(words[: self.cfg.reflective_prompt_reply_words]).strip()
if not hint:
hint = "calm and friendly"
return self.cfg.reflective_prompt_template.format(reply_hint=hint)
+10
View File
@@ -0,0 +1,10 @@
"""Thin wrappers around 3rd-party video generation models.
Each submodule isolates one external dependency so the real API surface
can be updated in a single file without touching the pipeline.
Submodules:
- ``wan22``: Wan2.2-Lightning image-to-video via LightX2V
- ``musetalk``: MuseTalk audio-driven lip-sync
- ``muxer``: ffmpeg-based frame/audio → MP4 encoding
"""
+164
View File
@@ -0,0 +1,164 @@
"""MuseTalk audio-driven lip-sync wrapper.
MuseTalk takes a sequence of face frames + driving audio and returns a new
sequence of frames where the mouth region is animated to match the audio.
This module isolates MuseTalk's real API behind a single ``lip_sync()``
method. MuseTalk's upstream Python surface varies between forks — if the
real import path or call signature differs, update this file only.
"""
from __future__ import annotations
import logging
import os
import numpy as np
log = logging.getLogger(__name__)
class MuseTalkEngine:
"""Thin wrapper over MuseTalk inference."""
def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
self.model_path = model_path
# MuseTalk's canonical entry point is ``musetalk.inference`` or a
# similar ``MuseTalkInfer`` class. Try the most common imports.
self._infer = self._load_impl(model_path)
log.info("MuseTalk engine loaded from %s", model_path)
@staticmethod
def _load_impl(model_path: str):
"""Load the MuseTalk inference implementation.
If none of the known entry points work the error message points at
this file so you know where to fix it.
"""
resolved = model_path
if not os.path.isdir(model_path) and "/" in model_path:
try:
from huggingface_hub import snapshot_download
resolved = snapshot_download(repo_id=model_path)
except Exception as e: # pragma: no cover
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
# Try upstream MuseTalk repo layout.
try:
from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found]
return MuseTalkInference(model_path=resolved)
except ImportError:
pass
try:
from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found]
return MuseTalkInfer(model_path=resolved)
except ImportError:
pass
try:
from musetalk import Inference # type: ignore[import-not-found]
return Inference(model_path=resolved)
except ImportError:
pass
raise RuntimeError(
"MuseTalk is installed but no known Python entry point was found. "
"Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
"to match the installed MuseTalk version."
)
# --- Inference ---------------------------------------------------------
def lip_sync(
self,
frames: np.ndarray,
audio: np.ndarray,
sample_rate: int,
fps: int,
) -> np.ndarray:
"""Return new frames with lip-sync applied to match ``audio``.
Args:
frames: uint8 ``[T, H, W, 3]`` RGB base frames.
audio: float32 mono 1D audio.
sample_rate: sample rate of ``audio``.
fps: frame rate of ``frames``.
Returns:
uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
to match audio duration at ``fps``.
"""
if frames.ndim != 4 or frames.shape[-1] != 3:
raise ValueError(
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
)
# Normalise frame count to audio duration so the caller doesn't have
# to do the arithmetic.
target_t = int(round(len(audio) / sample_rate * fps))
if target_t > 0 and len(frames) != target_t:
frames = _fit_frames_to_length(frames, target_t)
# The real MuseTalk call signature varies. Most common is a method
# like ``run(frames, audio, sr, fps)`` or ``infer(...)``.
for method_name in ("run", "infer", "lip_sync", "__call__"):
method = getattr(self._infer, method_name, None)
if method is None:
continue
try:
result = method(
frames=frames,
audio=audio,
sample_rate=sample_rate,
fps=fps,
)
return _ensure_uint8_rgb(result)
except TypeError:
# Try positional
try:
result = method(frames, audio, sample_rate, fps)
return _ensure_uint8_rgb(result)
except TypeError:
continue
raise RuntimeError(
"MuseTalk wrapper could not find a working inference method. "
"Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
)
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
"""Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.
Repeats with a ping-pong / boomerang tail so the seam between loops is
less jarring than a hard cut back to frame 0.
"""
if target_t <= 0:
return frames
t = len(frames)
if t == target_t:
return frames
if t > target_t:
return frames[:target_t]
# Extend via ping-pong looping
extended = [frames]
total = t
flip = True
while total < target_t:
seg = frames[::-1] if flip else frames
extended.append(seg)
total += t
flip = not flip
return np.concatenate(extended, axis=0)[:target_t]
def _ensure_uint8_rgb(arr) -> np.ndarray:
"""Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB."""
result = np.asarray(arr)
if result.dtype != np.uint8:
if result.dtype in (np.float32, np.float64):
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
else:
result = result.astype(np.uint8)
if result.ndim == 3:
result = result[None, ...]
return result
+146
View File
@@ -0,0 +1,146 @@
"""ffmpeg-based frame + audio → MP4 muxing.
Uses the system ``ffmpeg`` binary already installed in the Dockerfile.
No extra python dependencies beyond ``numpy``.
"""
from __future__ import annotations
import logging
import os
import shutil
import subprocess
import tempfile
import numpy as np
log = logging.getLogger(__name__)
def _ffmpeg_bin() -> str:
bin_path = shutil.which("ffmpeg")
if bin_path is None:
raise RuntimeError(
"ffmpeg binary not found on PATH. It should be installed by "
"the Dockerfile (line 13). Ensure you're running inside the "
"docker image or install ffmpeg locally."
)
return bin_path
def _write_raw_frames(frames: np.ndarray, path: str) -> tuple[int, int]:
"""Write uint8 RGB frames to ``path`` as raw rgb24 bytes. Returns (h, w)."""
if frames.ndim != 4 or frames.shape[-1] != 3:
raise ValueError(
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
)
if frames.dtype != np.uint8:
frames = frames.astype(np.uint8)
with open(path, "wb") as f:
f.write(frames.tobytes())
_, h, w, _ = frames.shape
return h, w
def _write_wav(audio: np.ndarray, sample_rate: int, path: str) -> None:
"""Write a float32 mono audio array to a 16-bit PCM WAV at ``path``."""
from scipy.io import wavfile # type: ignore[import-not-found]
audio = np.asarray(audio, dtype=np.float32).reshape(-1)
int16 = np.clip(audio * 32767.0, -32768, 32767).astype(np.int16)
wavfile.write(path, sample_rate, int16)
def frames_to_mp4_loop(frames: np.ndarray, fps: int) -> bytes:
"""Encode ``frames`` to a silent MP4 suitable for looping playback.
Used for the idle clip: no audio track, loopable on an HTMLMediaElement
without audible seams.
"""
if frames.size == 0:
raise ValueError("frames_to_mp4_loop: empty frames")
ffmpeg = _ffmpeg_bin()
with tempfile.TemporaryDirectory() as td:
raw_path = os.path.join(td, "frames.raw")
out_path = os.path.join(td, "out.mp4")
h, w = _write_raw_frames(frames, raw_path)
cmd = [
ffmpeg, "-y",
"-f", "rawvideo",
"-pix_fmt", "rgb24",
"-s", f"{w}x{h}",
"-r", str(fps),
"-i", raw_path,
"-an",
"-c:v", "libx264",
"-preset", "veryfast",
"-pix_fmt", "yuv420p",
"-movflags", "+faststart",
out_path,
]
log.debug("muxer idle clip: %s", " ".join(cmd))
_run_ffmpeg(cmd)
with open(out_path, "rb") as f:
return f.read()
def frames_and_audio_to_mp4(
frames: np.ndarray,
audio: np.ndarray,
sample_rate: int,
fps: int,
) -> bytes:
"""Encode ``frames`` + ``audio`` to an MP4 with H.264 video + AAC audio.
Used for per-turn speaking clips.
"""
if frames.size == 0:
raise ValueError("frames_and_audio_to_mp4: empty frames")
if audio.size == 0:
raise ValueError("frames_and_audio_to_mp4: empty audio")
ffmpeg = _ffmpeg_bin()
with tempfile.TemporaryDirectory() as td:
raw_path = os.path.join(td, "frames.raw")
wav_path = os.path.join(td, "audio.wav")
out_path = os.path.join(td, "out.mp4")
h, w = _write_raw_frames(frames, raw_path)
_write_wav(audio, sample_rate, wav_path)
cmd = [
ffmpeg, "-y",
"-f", "rawvideo",
"-pix_fmt", "rgb24",
"-s", f"{w}x{h}",
"-r", str(fps),
"-i", raw_path,
"-i", wav_path,
"-c:v", "libx264",
"-preset", "veryfast",
"-pix_fmt", "yuv420p",
"-c:a", "aac",
"-b:a", "128k",
"-shortest",
"-movflags", "+faststart",
out_path,
]
log.debug("muxer speaking clip: %s", " ".join(cmd))
_run_ffmpeg(cmd)
with open(out_path, "rb") as f:
return f.read()
def _run_ffmpeg(cmd: list[str]) -> None:
try:
proc = subprocess.run(
cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError as e:
log.error("ffmpeg failed (exit %d): %s", e.returncode, e.stderr.decode(errors="replace"))
raise
if proc.returncode != 0: # pragma: no cover
raise RuntimeError(f"ffmpeg returned {proc.returncode}")
+423
View File
@@ -0,0 +1,423 @@
"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V.
This wrapper targets LightX2V's actual Python entry points (verified against
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
from lightx2v.utils.set_config import set_config
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.infer import init_runner
args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...)
config = set_config(args)
input_info = init_empty_input_info(args.task, args.support_tasks)
runner = init_runner(config) # loads all weights — done ONCE
update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...})
runner.run_pipeline(input_info) # per-turn; MP4 written to save_result_path
# LoRA hot-swap:
runner.switch_lora(lora_path, strength) # swap in
runner.switch_lora("", 0.0) # remove
Model weights are loaded once at construction and held resident across turns
so reflective mode doesn't re-pay the load cost each reply.
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
The bf16 DIT shards under high_noise_model/
and low_noise_model/ are SKIPPED via
ignore_patterns — we replace them with fp8.
- lightx2v/Wan2.2-Distill-Models — exactly two safetensors files:
the fp8 e4m3 4-step distilled high/low
noise DIT checkpoints (~15 GB each).
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import random
import tempfile
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from server.video import LoRASpec
log = logging.getLogger(__name__)
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8
# files from the distill repo replace the DIT weights entirely. We must keep
# the config.json / index.json metadata under high_noise_model/ and
# low_noise_model/ (LightX2V's set_config reads architecture params like
# ``dim`` from them) and the tokenizer files under google/.
BASE_REPO_IGNORE_PATTERNS = [
"high_noise_model/*.safetensors",
"low_noise_model/*.safetensors",
"assets/*",
"examples/*",
"nohup.out",
"*.md",
]
class Wan22Pipeline:
"""Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights.
Constructor downloads (if needed) both HF repos, writes a runtime JSON
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
``generate_i2v`` runs one inference turn against the already-loaded runner.
"""
def __init__(
self,
base_repo: str,
fp8_repo: str,
config_json: str,
model_cls: str = "wan2.2_moe_distill",
resolution: int = 480,
fps: int = 16,
):
self.base_repo = base_repo
self.fp8_repo = fp8_repo
self.config_json_template = config_json
self.model_cls = model_cls
self.resolution = resolution
self.fps = fps
self._applied_loras: list[LoRASpec] = []
# 1. Resolve / download base repo (T5/VAE/config) and fp8 DIT ckpts.
self._model_root = self._ensure_base_repo(base_repo)
self._fp8_high, self._fp8_low = self._ensure_fp8_checkpoints(fp8_repo)
# 2. Materialize a runtime JSON config with absolute ckpt paths.
self._runtime_json_path = self._build_runtime_config()
# 3. Build the argparse-like namespace LightX2V.set_config() expects.
args = self._build_args(
model_cls=model_cls,
model_path=self._model_root,
config_json=self._runtime_json_path,
)
# 4. set_config → init_runner. Runner construction triggers weight load.
# Imports are scoped here so ``import server.video_models.wan22``
# never pulls in lightx2v (tests can import this module on CPU).
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
from lightx2v.infer import init_runner # type: ignore[import-not-found]
log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
model_cls, self._model_root)
self._config = set_config(args)
self._input_info_template = init_empty_input_info(
args.task, args.support_tasks
)
log.info("LightX2V init_runner — loading weights (this takes a while)...")
self._runner = init_runner(self._config)
log.info("LightX2V runner loaded; weights resident.")
# --- Weight provisioning -------------------------------------------------
@staticmethod
def _ensure_base_repo(base_repo: str) -> str:
"""Return a local directory containing the Wan2.2 base support files.
If ``base_repo`` is already a local directory, use it as-is. Otherwise
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
shards (they're replaced by the fp8 files).
"""
if os.path.isdir(base_repo):
return base_repo
from huggingface_hub import snapshot_download
log.info("Downloading Wan2.2 base support files from %s "
"(skipping bf16 DIT shards)...", base_repo)
return snapshot_download(
repo_id=base_repo,
ignore_patterns=BASE_REPO_IGNORE_PATTERNS,
)
@staticmethod
def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]:
"""Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair.
- If ``fp8_repo`` is a local directory, expect both files inside it.
- Otherwise treat it as a HF repo id and download only the two files
we need (not the ~150 GB of other variants in that repo).
"""
if not fp8_repo:
raise ValueError("fp8_repo must be a HF repo id or local directory.")
if os.path.isdir(fp8_repo):
high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE)
low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE)
if not (os.path.isfile(high) and os.path.isfile(low)):
raise FileNotFoundError(
f"fp8 checkpoints not found in {fp8_repo}: expected "
f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}"
)
return high, low
from huggingface_hub import hf_hub_download
log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo)
high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE)
low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE)
return high, low
def _build_runtime_config(self) -> str:
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
with open(self.config_json_template, "r", encoding="utf-8") as f:
cfg = json.load(f)
# Drop editorial comments before passing to LightX2V.
cfg.pop("_comment", None)
cfg["high_noise_quantized_ckpt"] = self._fp8_high
cfg["low_noise_quantized_ckpt"] = self._fp8_low
cfg.setdefault("fps", self.fps)
tmp = tempfile.NamedTemporaryFile(
prefix="wan22_fp8_", suffix=".json",
mode="w", delete=False, encoding="utf-8",
)
json.dump(cfg, tmp, indent=2)
tmp.close()
log.info("Runtime LightX2V config: %s", tmp.name)
return tmp.name
@staticmethod
def _build_args(
*, model_cls: str, model_path: str, config_json: str
) -> argparse.Namespace:
"""Mirror every field from ``lightx2v.infer.main``'s argparse so
``set_config`` finds the attributes it expects. We only customize the
model/task/path fields; everything else stays at the CLI defaults.
"""
return argparse.Namespace(
seed=42,
model_cls=model_cls,
task="i2v",
support_tasks=[],
model_path=model_path,
sf_model_path=None,
config_json=config_json,
use_prompt_enhancer=False,
prompt="",
negative_prompt="",
image_path="",
last_frame_path="",
audio_path="",
image_strength="1.0",
image_frame_idx="",
src_ref_images=None,
src_video=None,
src_mask=None,
src_pose_path=None,
src_face_path=None,
src_bg_path=None,
src_mask_path=None,
pose=None,
action_path=None,
action_ckpt=None,
save_result_path=None,
return_result_tensor=False,
target_shape=[],
target_video_length=81,
aspect_ratio="",
video_path=None,
sr_ratio=2.0,
)
# --- LoRA --------------------------------------------------------------
def load_loras(self, specs: list["LoRASpec"]) -> None:
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
to route the LoRA to the correct expert.
With ``lazy_load`` the DIT models are ``None`` at this point, so
runtime ``switch_lora`` is impossible. Instead we inject
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
the LoRAs are applied when the models materialise on first inference.
Without ``lazy_load`` (models already resident) we call
``switch_lora`` with explicit high/low keyword args.
"""
if not specs:
return
# Resolve every path up-front (may trigger HF download).
resolved: list[tuple["LoRASpec", str]] = []
for spec in specs:
local_path = self._resolve_lora_path(spec.path)
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
spec.name or spec.path, spec.weight, spec.target,
local_path)
resolved.append((spec, local_path))
lazy = self._config.get("lazy_load", False)
if lazy:
# Build the lora_configs list that LightX2V's lazy-load path
# reads inside MultiDistillModelStruct.infer().
lora_cfgs = []
for spec, local_path in resolved:
# LightX2V expects name "high_noise_model" / "low_noise_model"
cfg_name = {
"high_noise": "high_noise_model",
"low_noise": "low_noise_model",
}.get(spec.target)
if cfg_name is None:
raise ValueError(
f"LoRA target must be 'high_noise' or 'low_noise', "
f"got {spec.target!r}")
lora_cfgs.append({
"name": cfg_name,
"path": local_path,
"strength": spec.weight,
})
self._runner.set_config({
"lora_configs": lora_cfgs,
"lora_dynamic_apply": True,
})
else:
# Models are loaded — use runtime hot-swap.
high_path = high_strength = None
low_path = low_strength = None
for spec, local_path in resolved:
if spec.target == "high_noise":
high_path, high_strength = local_path, spec.weight
elif spec.target == "low_noise":
low_path, low_strength = local_path, spec.weight
else:
raise ValueError(
f"LoRA target must be 'high_noise' or 'low_noise', "
f"got {spec.target!r}")
kwargs: dict = {}
if high_path is not None:
kwargs["high_lora_path"] = high_path
kwargs["high_lora_strength"] = high_strength
if low_path is not None:
kwargs["low_lora_path"] = low_path
kwargs["low_lora_strength"] = low_strength
ok = self._runner.switch_lora(**kwargs)
if not ok:
raise RuntimeError(
"runner.switch_lora returned False. Check that your "
"LightX2V build supports runtime LoRA updates for "
f"{self.model_cls}.")
self._applied_loras = list(specs)
def unload_loras(self) -> None:
"""Remove all currently applied LoRAs."""
if not self._applied_loras:
return
lazy = self._config.get("lazy_load", False)
if lazy:
self._runner.set_config({
"lora_configs": None,
"lora_dynamic_apply": False,
})
# If models were materialised, drop them so the next inference
# recreates them without LoRAs.
model_struct = getattr(self._runner, "model", None)
if model_struct is not None and hasattr(model_struct, "model"):
for i in range(len(model_struct.model)):
model_struct.model[i] = None
else:
self._runner.switch_lora("", 0.0)
self._applied_loras = []
@staticmethod
def _resolve_lora_path(path: str) -> str:
"""Resolve a LoRA path. Supports:
- Absolute/relative local paths (returned as-is if the file exists)
- ``repo_id:filename`` HuggingFace references
"""
if os.path.isfile(path):
return path
if ":" in path and not path.startswith(("/", "./")):
repo_id, filename = path.split(":", 1)
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=repo_id, filename=filename)
return path
# --- Inference ---------------------------------------------------------
def generate_i2v(
self,
image_path: str,
prompt: str,
seconds: int,
seed: int | None = None,
negative_prompt: str = "",
) -> np.ndarray:
"""Run image-to-video inference and return decoded frames.
Returns ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB.
"""
if seed is None:
seed = random.randint(0, 2**31 - 1)
# Wan2.2 target_video_length is "frames including the conditioning
# frame", so N seconds → N*fps + 1.
target_frames = seconds * self.fps + 1
from lightx2v.utils.input_info import update_input_info_from_dict # type: ignore[import-not-found]
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
out_path = tf.name
try:
log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d%s",
prompt[:80], seconds, seed, out_path)
update_input_info_from_dict(
self._input_info_template,
{
"seed": seed,
"prompt": prompt,
"negative_prompt": negative_prompt,
"image_path": image_path,
"save_result_path": out_path,
"target_video_length": target_frames,
"return_result_tensor": False,
},
)
self._runner.run_pipeline(self._input_info_template)
return _read_mp4_to_frames(out_path)
finally:
try:
os.remove(out_path)
except OSError:
pass
# --- MP4 decoding helper ------------------------------------------------------
def _read_mp4_to_frames(path: str) -> np.ndarray:
"""Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``."""
try:
import imageio.v3 as iio # type: ignore[import-not-found]
frames = iio.imread(path, plugin="pyav")
arr = np.asarray(frames)
if arr.ndim == 3:
arr = arr[None, ...]
return arr.astype(np.uint8)
except Exception as e: # pragma: no cover - fallback path
log.warning("imageio decode failed (%s); falling back to cv2", e)
import cv2 # type: ignore[import-not-found]
cap = cv2.VideoCapture(path)
frames: list[np.ndarray] = []
while True:
ok, frame = cap.read()
if not ok:
break
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()
if not frames:
raise RuntimeError(f"Failed to decode any frames from {path}")
return np.stack(frames, axis=0).astype(np.uint8)