first stab at adding video
This commit is contained in:
+121
-2
@@ -1,23 +1,27 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.params import Form
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from server.audio_utils import pcm_bytes_to_float32
|
||||
from server.models import ModelManager
|
||||
from server.pipeline import ConversationSession
|
||||
from server.video import LoRASpec
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio")
|
||||
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
|
||||
AVATAR_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "avatars")
|
||||
os.makedirs(AVATAR_DIR, exist_ok=True)
|
||||
|
||||
model_mgr = ModelManager()
|
||||
|
||||
@@ -47,6 +51,110 @@ async def set_voice(voice: str = Form(...), lang: str = Form("a")):
|
||||
return {"status": "ok", "voice": voice}
|
||||
|
||||
|
||||
# --- Video / avatar endpoints ---------------------------------------------
|
||||
|
||||
def _require_video() -> "object":
|
||||
"""Return the video engine, or raise 404 if video mode isn't enabled."""
|
||||
ve = model_mgr.video_engine
|
||||
if ve is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Video engine disabled. Set config.video.enabled=true and restart.",
|
||||
)
|
||||
return ve
|
||||
|
||||
|
||||
@app.post("/api/set-avatar")
|
||||
async def set_avatar(image: UploadFile):
|
||||
"""Upload an avatar image and (re)generate cached clips."""
|
||||
ve = _require_video()
|
||||
suffix = os.path.splitext(image.filename or "avatar.png")[1] or ".png"
|
||||
dest = os.path.join(AVATAR_DIR, f"avatar{suffix}")
|
||||
with open(dest, "wb") as f:
|
||||
f.write(await image.read())
|
||||
log.info("Avatar saved to %s", dest)
|
||||
|
||||
import asyncio
|
||||
try:
|
||||
await asyncio.to_thread(ve.set_avatar, dest)
|
||||
except Exception as e:
|
||||
log.exception("set_avatar failed")
|
||||
raise HTTPException(status_code=500, detail=f"Avatar setup failed: {e}")
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"avatar_path": dest,
|
||||
"idle_clip_url": "/api/idle-clip",
|
||||
"mode": ve.cfg.mode,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/idle-clip")
|
||||
async def idle_clip():
|
||||
"""Return the cached idle loop MP4."""
|
||||
ve = _require_video()
|
||||
data = ve.get_idle_clip()
|
||||
if data is None:
|
||||
raise HTTPException(status_code=404, detail="No idle clip. Upload an avatar first.")
|
||||
return Response(content=data, media_type="video/mp4")
|
||||
|
||||
|
||||
@app.post("/api/set-video-mode")
|
||||
async def set_video_mode(mode: str = Form(...)):
|
||||
"""Switch between 'off', 'library', and 'reflective'.
|
||||
|
||||
'off' leaves the video engine loaded but makes the pipeline take the
|
||||
PCM streaming path on subsequent turns (by marking the engine not-ready
|
||||
from the client's perspective via a simple flag).
|
||||
"""
|
||||
ve = _require_video()
|
||||
if mode not in ("off", "library", "reflective"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="mode must be one of: off, library, reflective",
|
||||
)
|
||||
# Switching between library/reflective changes how set_avatar prebakes
|
||||
# clips. Require a fresh avatar upload afterwards to re-bake.
|
||||
ve.cfg.mode = mode
|
||||
return {"status": "ok", "mode": mode, "note": "Re-upload avatar to re-bake library clips." if mode == "library" else ""}
|
||||
|
||||
|
||||
@app.post("/api/reload-loras")
|
||||
async def reload_loras(body: dict):
|
||||
"""Hot-reload LoRA stack. Body: ``{"loras": [{"path","weight","target","name"}]}``.
|
||||
|
||||
Regenerates the idle clip if an avatar is already set, since the new
|
||||
LoRAs change the base style.
|
||||
"""
|
||||
ve = _require_video()
|
||||
raw = body.get("loras") or []
|
||||
specs: list[LoRASpec] = []
|
||||
for entry in raw:
|
||||
if not entry or "path" not in entry:
|
||||
continue
|
||||
target = str(entry.get("target", "both")).lower()
|
||||
if target not in ("high_noise", "low_noise", "both"):
|
||||
target = "both"
|
||||
specs.append(
|
||||
LoRASpec(
|
||||
path=str(entry["path"]),
|
||||
weight=float(entry.get("weight", 1.0)),
|
||||
target=target, # type: ignore[arg-type]
|
||||
name=entry.get("name"),
|
||||
)
|
||||
)
|
||||
import asyncio
|
||||
try:
|
||||
await asyncio.to_thread(ve.load_loras, specs)
|
||||
if ve.avatar_path:
|
||||
log.info("Regenerating idle clip after LoRA reload.")
|
||||
await asyncio.to_thread(ve.set_avatar, ve.avatar_path)
|
||||
except Exception as e:
|
||||
log.exception("reload_loras failed")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
return {"status": "ok", "lora_count": len(specs), "idle_clip_url": "/api/idle-clip"}
|
||||
|
||||
|
||||
@app.websocket("/ws/chat")
|
||||
async def websocket_chat(ws: WebSocket):
|
||||
await ws.accept()
|
||||
@@ -61,6 +169,17 @@ async def websocket_chat(ws: WebSocket):
|
||||
session = ConversationSession(model_mgr, send_json, send_bytes)
|
||||
await session.start()
|
||||
|
||||
# Tell the client whether video mode is active so it knows whether to
|
||||
# suppress PCM playback and wait for speaking_clip messages instead.
|
||||
ve = model_mgr.video_engine
|
||||
await send_json({
|
||||
"type": "video_mode",
|
||||
"enabled": ve is not None,
|
||||
"ready": ve.is_ready() if ve is not None else False,
|
||||
"mode": ve.cfg.mode if ve is not None else "off",
|
||||
"idle_clip_url": "/api/idle-clip" if (ve is not None and ve.get_idle_clip()) else None,
|
||||
})
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await ws.receive()
|
||||
|
||||
+26
-2
@@ -5,6 +5,7 @@ from server.vad import StreamingVAD
|
||||
from server.asr import ASREngine
|
||||
from server.llm import LLMEngine
|
||||
from server.tts import TTSEngine
|
||||
from server.video import VideoConfig, VideoEngine
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,6 +32,7 @@ class ModelManager:
|
||||
self.asr_engine: ASREngine | None = None
|
||||
self.llm_engine: LLMEngine | None = None
|
||||
self.tts_engine: TTSEngine | None = None
|
||||
self.video_engine: VideoEngine | None = None
|
||||
|
||||
def load_all(self):
|
||||
"""Load all models sequentially. Call from the main process."""
|
||||
@@ -38,6 +40,7 @@ class ModelManager:
|
||||
self._load_asr()
|
||||
self._load_llm()
|
||||
self._load_tts()
|
||||
self._load_video()
|
||||
log.info("All models loaded successfully.")
|
||||
|
||||
def _load_vad(self):
|
||||
@@ -84,8 +87,8 @@ class ModelManager:
|
||||
log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_name = "Qwen/Qwen3.5-0.8B"
|
||||
|
||||
# model_name = "Qwen/Qwen3.5-0.8B"
|
||||
model_name = "dphn/Dolphin-X1-8B-FP8"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
device = get_device()
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -101,6 +104,27 @@ class ModelManager:
|
||||
self.tts_engine = TTSEngine()
|
||||
log.info("Kokoro TTS loaded.")
|
||||
|
||||
def _load_video(self):
|
||||
"""Load the avatar video stack iff config.video.enabled is true.
|
||||
|
||||
Leaves ``video_engine`` as None when disabled so existing voice flow
|
||||
is untouched. Later phases replace this stub with actual Wan2.2 +
|
||||
MuseTalk loading inside ``VideoEngine``.
|
||||
"""
|
||||
from server.config import config
|
||||
|
||||
video_cfg_raw = config.get("video", {}) or {}
|
||||
if not video_cfg_raw.get("enabled", False):
|
||||
log.info("Video engine disabled (config.video.enabled=false). Skipping load.")
|
||||
return
|
||||
|
||||
log.info("Loading avatar video engine...")
|
||||
cfg = VideoConfig.from_dict(video_cfg_raw)
|
||||
self.video_engine = VideoEngine(cfg)
|
||||
if cfg.loras:
|
||||
self.video_engine.load_loras(cfg.loras)
|
||||
log.info("Avatar video engine loaded (mode=%s).", cfg.mode)
|
||||
|
||||
def create_vad(self) -> StreamingVAD:
|
||||
"""Create a new StreamingVAD instance for a client session."""
|
||||
return StreamingVAD(self.vad_model)
|
||||
|
||||
+73
-14
@@ -157,11 +157,20 @@ class ConversationSession:
|
||||
|
||||
# TTS - stream chunks with per-sentence text
|
||||
await self.send_json({"type": "status", "state": "speaking"})
|
||||
|
||||
# Video-mode branch: if a video engine is loaded AND an avatar is
|
||||
# set, buffer the full TTS output into a single blob, run MuseTalk
|
||||
# lip-sync (library or reflective source), mux to MP4, and send the
|
||||
# full clip + text in one shot. The client plays the MP4 (which
|
||||
# carries audio) instead of the per-chunk PCM path.
|
||||
video_engine = getattr(self.models, "video_engine", None)
|
||||
use_video = video_engine is not None and video_engine.is_ready()
|
||||
|
||||
chunk_queue = queue.Queue()
|
||||
self._last_played_chunk_id = None
|
||||
|
||||
segments = _split_into_segments(response)
|
||||
log.info(f"TTS: split response into {len(segments)} segments")
|
||||
log.info(f"TTS: split response into {len(segments)} segments (video={use_video})")
|
||||
|
||||
def _tts_worker():
|
||||
try:
|
||||
@@ -187,6 +196,10 @@ class ConversationSession:
|
||||
chunk_id = 0
|
||||
# Maps chunk_id -> cumulative text up to and including that chunk
|
||||
chunk_text_map: dict[int, str] = {}
|
||||
# Video mode accumulator: we buffer all TTS audio into one float32
|
||||
# array so MuseTalk can align against the full utterance.
|
||||
audio_buffer: list[np.ndarray] = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
|
||||
@@ -202,23 +215,69 @@ class ConversationSession:
|
||||
spoken_text += sentence_text
|
||||
chunk_text_map[chunk_id] = spoken_text
|
||||
|
||||
await self.send_json({
|
||||
"type": "response_text",
|
||||
"text": sentence_text,
|
||||
"chunk_id": chunk_id,
|
||||
"final": False,
|
||||
})
|
||||
pcm_bytes = float32_to_pcm_bytes(audio)
|
||||
try:
|
||||
await self.send_bytes(pcm_bytes)
|
||||
except Exception:
|
||||
log.warning("Failed to send audio, client disconnected.")
|
||||
self.cancel_event.set()
|
||||
break
|
||||
if use_video:
|
||||
audio_buffer.append(audio)
|
||||
# Don't stream text or PCM during video mode — we'll send
|
||||
# everything after the clip renders so the client doesn't
|
||||
# start displaying text before the video is ready.
|
||||
else:
|
||||
await self.send_json({
|
||||
"type": "response_text",
|
||||
"text": sentence_text,
|
||||
"chunk_id": chunk_id,
|
||||
"final": False,
|
||||
})
|
||||
pcm_bytes = float32_to_pcm_bytes(audio)
|
||||
try:
|
||||
await self.send_bytes(pcm_bytes)
|
||||
except Exception:
|
||||
log.warning("Failed to send audio, client disconnected.")
|
||||
self.cancel_event.set()
|
||||
break
|
||||
chunk_id += 1
|
||||
|
||||
tts_thread.join(timeout=2.0)
|
||||
|
||||
# Video mode: render the speaking clip now that TTS is done.
|
||||
if use_video and audio_buffer and not self.cancel_event.is_set():
|
||||
try:
|
||||
full_audio = np.concatenate(audio_buffer).astype(np.float32)
|
||||
sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000)
|
||||
log.info(
|
||||
"Video: rendering speaking clip (audio=%ds, mode=%s)",
|
||||
int(len(full_audio) / sample_rate), video_engine.cfg.mode,
|
||||
)
|
||||
mp4_bytes = await asyncio.to_thread(
|
||||
video_engine.generate_speaking_clip,
|
||||
full_audio,
|
||||
sample_rate,
|
||||
response,
|
||||
)
|
||||
if self.cancel_event.is_set():
|
||||
log.info("Video clip discarded (cancelled during render).")
|
||||
else:
|
||||
duration_ms = int(len(full_audio) / sample_rate * 1000)
|
||||
await self.send_json({
|
||||
"type": "speaking_clip",
|
||||
"chunk_id": 0,
|
||||
"duration_ms": duration_ms,
|
||||
"text": response,
|
||||
"size_bytes": len(mp4_bytes),
|
||||
})
|
||||
await self.send_bytes(mp4_bytes)
|
||||
except Exception:
|
||||
log.exception("Video speaking-clip render failed; falling back silently.")
|
||||
# Best-effort: tell the client nothing was spoken visually.
|
||||
try:
|
||||
await self.send_json({
|
||||
"type": "response_text",
|
||||
"text": response,
|
||||
"chunk_id": 0,
|
||||
"final": True,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Determine what was actually heard by the client
|
||||
was_interrupted = spoken_text.strip() != response.strip()
|
||||
if was_interrupted and self._last_played_chunk_id is not None:
|
||||
|
||||
+391
@@ -0,0 +1,391 @@
|
||||
"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync.
|
||||
|
||||
Top-level orchestrator. The heavy 3rd-party model code is isolated in
|
||||
``server/video_models/`` so each wrapper can be updated independently.
|
||||
|
||||
This module is only imported by ``server/models.py`` when
|
||||
``config.video.enabled`` is true. When disabled, the existing voice pipeline
|
||||
is completely untouched.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LoRATarget = Literal["high_noise", "low_noise", "both"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRASpec:
|
||||
"""One LoRA adapter entry from ``config.video.loras``.
|
||||
|
||||
Wan2.2 I2V is a Mixture-of-Experts model with separate high-noise and
|
||||
low-noise sub-models. Most LightX2V distill LoRAs come paired (one per
|
||||
sub-model) and must be applied to the correct target. Allow
|
||||
``target="both"`` for LoRAs that should be applied to both sub-models
|
||||
(e.g. style LoRAs).
|
||||
"""
|
||||
|
||||
path: str
|
||||
weight: float = 1.0
|
||||
target: LoRATarget = "both"
|
||||
name: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoConfig:
|
||||
"""Flattened view of the ``video:`` section of config.yml."""
|
||||
|
||||
enabled: bool = False
|
||||
backend: str = "lightx2v"
|
||||
mode: str = "reflective" # "library" | "reflective"
|
||||
resolution: int = 480
|
||||
fps: int = 16
|
||||
library_base_clip_count: int = 4
|
||||
library_base_clip_seconds: int = 6
|
||||
reflective_clip_seconds: int = 5
|
||||
reflective_prompt_template: str = (
|
||||
"webcam view of a person speaking, {reply_hint}, casual gestures, "
|
||||
"natural lighting, soft focus background"
|
||||
)
|
||||
reflective_prompt_reply_words: int = 18
|
||||
loras: list[LoRASpec] = field(default_factory=list)
|
||||
|
||||
# Model paths — can be overridden via config.yml.video.models.
|
||||
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
|
||||
# The bf16 DIT shards in this repo are skipped — we
|
||||
# replace them with the fp8 files from wan22_fp8_repo.
|
||||
# wan22_fp8_repo : HF repo id (or local dir) providing the two fp8 e4m3
|
||||
# 4-step distilled DIT checkpoints (~15 GB each).
|
||||
# wan22_config_json: path to the LightX2V inference config template the
|
||||
# Wan22Pipeline will fill in with absolute ckpt paths.
|
||||
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
|
||||
wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models"
|
||||
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
||||
wan22_model_cls: str = "wan2.2_moe_distill"
|
||||
musetalk_model_path: str = "TMElyralab/MuseTalk"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, raw: dict) -> "VideoConfig":
|
||||
raw = raw or {}
|
||||
library = raw.get("library", {}) or {}
|
||||
reflective = raw.get("reflective", {}) or {}
|
||||
models_raw = raw.get("models", {}) or {}
|
||||
loras_raw = raw.get("loras") or []
|
||||
|
||||
default_template = (
|
||||
"webcam view of a person speaking, {reply_hint}, casual gestures, "
|
||||
"natural lighting, soft focus background"
|
||||
)
|
||||
|
||||
loras: list[LoRASpec] = []
|
||||
for entry in loras_raw:
|
||||
if not entry or "path" not in entry:
|
||||
continue
|
||||
target = str(entry.get("target", "both")).lower()
|
||||
if target not in ("high_noise", "low_noise", "both"):
|
||||
log.warning(
|
||||
"LoRA %s: invalid target %r, defaulting to 'both'",
|
||||
entry.get("path"), target,
|
||||
)
|
||||
target = "both"
|
||||
loras.append(
|
||||
LoRASpec(
|
||||
path=str(entry["path"]),
|
||||
weight=float(entry.get("weight", 1.0)),
|
||||
target=target, # type: ignore[arg-type]
|
||||
name=entry.get("name"),
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
enabled=bool(raw.get("enabled", False)),
|
||||
backend=str(raw.get("backend", "lightx2v")),
|
||||
mode=str(raw.get("mode", "reflective")),
|
||||
resolution=int(raw.get("resolution", 480)),
|
||||
fps=int(raw.get("fps", 16)),
|
||||
library_base_clip_count=int(library.get("base_clip_count", 4)),
|
||||
library_base_clip_seconds=int(library.get("base_clip_seconds", 6)),
|
||||
reflective_clip_seconds=int(reflective.get("clip_seconds", 5)),
|
||||
reflective_prompt_template=str(
|
||||
reflective.get("clip_prompt_template", default_template)
|
||||
),
|
||||
reflective_prompt_reply_words=int(reflective.get("prompt_reply_words", 18)),
|
||||
loras=loras,
|
||||
wan22_base_repo=str(
|
||||
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
|
||||
),
|
||||
wan22_fp8_repo=str(
|
||||
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models")
|
||||
),
|
||||
wan22_config_json=str(
|
||||
models_raw.get(
|
||||
"wan22_config_json",
|
||||
"/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
|
||||
)
|
||||
),
|
||||
wan22_model_cls=str(
|
||||
models_raw.get("wan22_model_cls", "wan2.2_moe_distill")
|
||||
),
|
||||
musetalk_model_path=str(
|
||||
models_raw.get("musetalk_path", "TMElyralab/MuseTalk")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Library-mode base-clip prompts. Varied gestures so the pre-baked set feels
|
||||
# less repetitive when replayed. Kept module-level so tests can import them.
|
||||
LIBRARY_BASE_PROMPTS = [
|
||||
"webcam view of a person speaking, subtle head nods, casual expression, "
|
||||
"natural lighting, soft focus background",
|
||||
"webcam view of a person speaking, slight smile, gentle hand gesture, "
|
||||
"natural lighting, soft focus background",
|
||||
"webcam view of a person speaking, looking thoughtful, small head tilt, "
|
||||
"natural lighting, soft focus background",
|
||||
"webcam view of a person speaking, engaged and attentive, minor shoulder "
|
||||
"movement, natural lighting, soft focus background",
|
||||
"webcam view of a person speaking, relaxed posture, blinking naturally, "
|
||||
"natural lighting, soft focus background",
|
||||
]
|
||||
|
||||
IDLE_PROMPT = (
|
||||
"webcam view of a person listening quietly, mouth closed, subtle "
|
||||
"breathing, occasional blinks, calm expression, natural lighting, "
|
||||
"soft focus background"
|
||||
)
|
||||
|
||||
|
||||
class VideoEngine:
|
||||
"""Top-level video generation orchestrator.
|
||||
|
||||
Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's
|
||||
pre-rendered clips. Exposed to ``ConversationSession`` via
|
||||
``ModelManager.video_engine``.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: VideoConfig):
|
||||
self.cfg = cfg
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Avatar state
|
||||
self.avatar_path: str | None = None
|
||||
self.idle_clip_mp4: bytes | None = None
|
||||
# Pre-baked speaking base clips for library mode. Each entry is a
|
||||
# contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8.
|
||||
self.speaking_base_frames: list[np.ndarray] = []
|
||||
# Round-robin pointer for picking a library clip per turn
|
||||
self._library_cursor = 0
|
||||
|
||||
# Model wrappers — instantiated lazily by ``load_models()`` so unit
|
||||
# tests can exercise VideoEngine without touching CUDA at all.
|
||||
self._wan22 = None # server.video_models.wan22.Wan22Pipeline
|
||||
self._musetalk = None # server.video_models.musetalk.MuseTalkEngine
|
||||
|
||||
log.info(
|
||||
"VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).",
|
||||
cfg.mode, cfg.resolution, cfg.fps, len(cfg.loras),
|
||||
)
|
||||
|
||||
# --- Model loading --------------------------------------------------
|
||||
|
||||
def load_models(self) -> None:
|
||||
"""Instantiate the underlying model wrappers.
|
||||
|
||||
Separated from ``__init__`` so tests can mock ``_wan22``/``_musetalk``
|
||||
without triggering Wan2.2's ~12-16GB VRAM allocation.
|
||||
"""
|
||||
from server.video_models.wan22 import Wan22Pipeline
|
||||
from server.video_models.musetalk import MuseTalkEngine
|
||||
|
||||
log.info(
|
||||
"Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...",
|
||||
self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo,
|
||||
)
|
||||
self._wan22 = Wan22Pipeline(
|
||||
base_repo=self.cfg.wan22_base_repo,
|
||||
fp8_repo=self.cfg.wan22_fp8_repo,
|
||||
config_json=self.cfg.wan22_config_json,
|
||||
model_cls=self.cfg.wan22_model_cls,
|
||||
resolution=self.cfg.resolution,
|
||||
fps=self.cfg.fps,
|
||||
)
|
||||
if self.cfg.loras:
|
||||
self._wan22.load_loras(self.cfg.loras)
|
||||
log.info("Wan2.2 pipeline ready.")
|
||||
|
||||
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
|
||||
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
|
||||
log.info("MuseTalk engine ready.")
|
||||
|
||||
# --- Readiness ------------------------------------------------------
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""True when an avatar is set and a speaking clip can be produced."""
|
||||
return (
|
||||
self._wan22 is not None
|
||||
and self._musetalk is not None
|
||||
and self.avatar_path is not None
|
||||
and self.idle_clip_mp4 is not None
|
||||
)
|
||||
|
||||
# --- LoRA management ------------------------------------------------
|
||||
|
||||
def load_loras(self, specs: list[LoRASpec]) -> None:
|
||||
"""Apply a list of LoRA adapters to the Wan2.2 base.
|
||||
|
||||
Replaces any previously applied LoRAs. Safe to call after init for
|
||||
hot-reload via ``POST /api/reload-loras``.
|
||||
"""
|
||||
if self._wan22 is None:
|
||||
raise RuntimeError("load_loras called before load_models()")
|
||||
with self._lock:
|
||||
self._wan22.unload_loras()
|
||||
self._wan22.load_loras(specs)
|
||||
self.cfg.loras = list(specs)
|
||||
log.info("Applied %d LoRA(s): %s",
|
||||
len(specs),
|
||||
", ".join(s.name or s.path for s in specs) or "<none>")
|
||||
|
||||
# --- Avatar lifecycle ----------------------------------------------
|
||||
|
||||
def set_avatar(self, image_path: str) -> None:
|
||||
"""Register an avatar image and pre-generate cached clips.
|
||||
|
||||
- Always: generate the idle loop.
|
||||
- Library mode: also pre-generate ``library.base_clip_count``
|
||||
speaking base clips.
|
||||
- Reflective mode: idle loop only.
|
||||
"""
|
||||
if self._wan22 is None:
|
||||
raise RuntimeError("set_avatar called before load_models()")
|
||||
|
||||
with self._lock:
|
||||
log.info("Setting avatar: %s", image_path)
|
||||
self.avatar_path = image_path
|
||||
# Drop any previously cached clips so the new avatar's library
|
||||
# doesn't mix with the old.
|
||||
self.speaking_base_frames = []
|
||||
self.idle_clip_mp4 = None
|
||||
|
||||
# Idle clip: short loop, neutral/listening prompt.
|
||||
log.info("Generating idle clip...")
|
||||
idle_frames = self._wan22.generate_i2v(
|
||||
image_path=image_path,
|
||||
prompt=IDLE_PROMPT,
|
||||
seconds=self.cfg.library_base_clip_seconds,
|
||||
seed=0,
|
||||
)
|
||||
from server.video_models.muxer import frames_to_mp4_loop
|
||||
self.idle_clip_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps)
|
||||
log.info("Idle clip ready (%d bytes).", len(self.idle_clip_mp4))
|
||||
|
||||
# Library mode: pre-bake N speaking base clips.
|
||||
if self.cfg.mode == "library":
|
||||
n = self.cfg.library_base_clip_count
|
||||
log.info("Pre-baking %d speaking base clip(s) for library mode.", n)
|
||||
for i in range(n):
|
||||
prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)]
|
||||
frames = self._wan22.generate_i2v(
|
||||
image_path=image_path,
|
||||
prompt=prompt,
|
||||
seconds=self.cfg.library_base_clip_seconds,
|
||||
seed=i + 1,
|
||||
)
|
||||
self.speaking_base_frames.append(frames)
|
||||
log.info(" base clip %d/%d rendered", i + 1, n)
|
||||
|
||||
self._library_cursor = 0
|
||||
|
||||
def get_idle_clip(self) -> bytes | None:
|
||||
return self.idle_clip_mp4
|
||||
|
||||
# --- Per-turn generation -------------------------------------------
|
||||
|
||||
def generate_speaking_clip(
|
||||
self,
|
||||
audio_f32: np.ndarray,
|
||||
sample_rate: int,
|
||||
reply_text: str,
|
||||
) -> bytes:
|
||||
"""Produce a lip-synced MP4 for one assistant turn."""
|
||||
if not self.is_ready():
|
||||
raise RuntimeError(
|
||||
"generate_speaking_clip: engine not ready "
|
||||
"(avatar set? models loaded?)"
|
||||
)
|
||||
assert self._wan22 is not None
|
||||
assert self._musetalk is not None
|
||||
|
||||
# 1. Source base frames.
|
||||
if self.cfg.mode == "library":
|
||||
base_frames = self._pick_library_frames(audio_f32, sample_rate)
|
||||
else: # reflective
|
||||
prompt = self._derive_prompt(reply_text)
|
||||
log.info("Reflective prompt: %s", prompt[:120])
|
||||
base_frames = self._wan22.generate_i2v(
|
||||
image_path=self.avatar_path or "",
|
||||
prompt=prompt,
|
||||
seconds=self.cfg.reflective_clip_seconds,
|
||||
seed=None, # random each turn
|
||||
)
|
||||
|
||||
# 2. Lip-sync the base frames to the given audio.
|
||||
synced_frames = self._musetalk.lip_sync(
|
||||
frames=base_frames,
|
||||
audio=audio_f32,
|
||||
sample_rate=sample_rate,
|
||||
fps=self.cfg.fps,
|
||||
)
|
||||
|
||||
# 3. Mux frames + audio into an MP4.
|
||||
from server.video_models.muxer import frames_and_audio_to_mp4
|
||||
return frames_and_audio_to_mp4(
|
||||
frames=synced_frames,
|
||||
audio=audio_f32,
|
||||
sample_rate=sample_rate,
|
||||
fps=self.cfg.fps,
|
||||
)
|
||||
|
||||
def _pick_library_frames(
|
||||
self, audio_f32: np.ndarray, sample_rate: int
|
||||
) -> np.ndarray:
|
||||
"""Round-robin pick from the pre-baked library, clipped or looped
|
||||
to roughly the audio's duration so there's no long freeze frame."""
|
||||
if not self.speaking_base_frames:
|
||||
raise RuntimeError(
|
||||
"Library mode has no pre-baked base clips. "
|
||||
"Was set_avatar called with mode=library?"
|
||||
)
|
||||
frames = self.speaking_base_frames[
|
||||
self._library_cursor % len(self.speaking_base_frames)
|
||||
]
|
||||
self._library_cursor += 1
|
||||
|
||||
target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
|
||||
if target_frames <= 0:
|
||||
return frames
|
||||
if target_frames <= len(frames):
|
||||
return frames[:target_frames]
|
||||
# Loop (with a mirror tail to soften the seam) to cover longer audio.
|
||||
loops = target_frames // len(frames) + 1
|
||||
extended = np.concatenate([frames] * loops, axis=0)
|
||||
return extended[:target_frames]
|
||||
|
||||
def _derive_prompt(self, reply_text: str) -> str:
|
||||
"""Template-based prompt builder for reflective mode.
|
||||
|
||||
Takes up to ``prompt_reply_words`` words from the start of the reply
|
||||
and interpolates them into the configured template. Cheap,
|
||||
deterministic, no extra LLM call.
|
||||
"""
|
||||
words = (reply_text or "").split()
|
||||
hint = " ".join(words[: self.cfg.reflective_prompt_reply_words]).strip()
|
||||
if not hint:
|
||||
hint = "calm and friendly"
|
||||
return self.cfg.reflective_prompt_template.format(reply_hint=hint)
|
||||
@@ -0,0 +1,10 @@
|
||||
"""Thin wrappers around 3rd-party video generation models.
|
||||
|
||||
Each submodule isolates one external dependency so the real API surface
|
||||
can be updated in a single file without touching the pipeline.
|
||||
|
||||
Submodules:
|
||||
- ``wan22``: Wan2.2-Lightning image-to-video via LightX2V
|
||||
- ``musetalk``: MuseTalk audio-driven lip-sync
|
||||
- ``muxer``: ffmpeg-based frame/audio → MP4 encoding
|
||||
"""
|
||||
@@ -0,0 +1,164 @@
|
||||
"""MuseTalk audio-driven lip-sync wrapper.
|
||||
|
||||
MuseTalk takes a sequence of face frames + driving audio and returns a new
|
||||
sequence of frames where the mouth region is animated to match the audio.
|
||||
|
||||
This module isolates MuseTalk's real API behind a single ``lip_sync()``
|
||||
method. MuseTalk's upstream Python surface varies between forks — if the
|
||||
real import path or call signature differs, update this file only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MuseTalkEngine:
|
||||
"""Thin wrapper over MuseTalk inference."""
|
||||
|
||||
def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
|
||||
self.model_path = model_path
|
||||
|
||||
# MuseTalk's canonical entry point is ``musetalk.inference`` or a
|
||||
# similar ``MuseTalkInfer`` class. Try the most common imports.
|
||||
self._infer = self._load_impl(model_path)
|
||||
log.info("MuseTalk engine loaded from %s", model_path)
|
||||
|
||||
@staticmethod
|
||||
def _load_impl(model_path: str):
|
||||
"""Load the MuseTalk inference implementation.
|
||||
|
||||
If none of the known entry points work the error message points at
|
||||
this file so you know where to fix it.
|
||||
"""
|
||||
resolved = model_path
|
||||
if not os.path.isdir(model_path) and "/" in model_path:
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
resolved = snapshot_download(repo_id=model_path)
|
||||
except Exception as e: # pragma: no cover
|
||||
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
|
||||
|
||||
# Try upstream MuseTalk repo layout.
|
||||
try:
|
||||
from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found]
|
||||
return MuseTalkInference(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found]
|
||||
return MuseTalkInfer(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from musetalk import Inference # type: ignore[import-not-found]
|
||||
return Inference(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
raise RuntimeError(
|
||||
"MuseTalk is installed but no known Python entry point was found. "
|
||||
"Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
|
||||
"to match the installed MuseTalk version."
|
||||
)
|
||||
|
||||
# --- Inference ---------------------------------------------------------
|
||||
|
||||
def lip_sync(
|
||||
self,
|
||||
frames: np.ndarray,
|
||||
audio: np.ndarray,
|
||||
sample_rate: int,
|
||||
fps: int,
|
||||
) -> np.ndarray:
|
||||
"""Return new frames with lip-sync applied to match ``audio``.
|
||||
|
||||
Args:
|
||||
frames: uint8 ``[T, H, W, 3]`` RGB base frames.
|
||||
audio: float32 mono 1D audio.
|
||||
sample_rate: sample rate of ``audio``.
|
||||
fps: frame rate of ``frames``.
|
||||
|
||||
Returns:
|
||||
uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
|
||||
to match audio duration at ``fps``.
|
||||
"""
|
||||
if frames.ndim != 4 or frames.shape[-1] != 3:
|
||||
raise ValueError(
|
||||
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
|
||||
)
|
||||
|
||||
# Normalise frame count to audio duration so the caller doesn't have
|
||||
# to do the arithmetic.
|
||||
target_t = int(round(len(audio) / sample_rate * fps))
|
||||
if target_t > 0 and len(frames) != target_t:
|
||||
frames = _fit_frames_to_length(frames, target_t)
|
||||
|
||||
# The real MuseTalk call signature varies. Most common is a method
|
||||
# like ``run(frames, audio, sr, fps)`` or ``infer(...)``.
|
||||
for method_name in ("run", "infer", "lip_sync", "__call__"):
|
||||
method = getattr(self._infer, method_name, None)
|
||||
if method is None:
|
||||
continue
|
||||
try:
|
||||
result = method(
|
||||
frames=frames,
|
||||
audio=audio,
|
||||
sample_rate=sample_rate,
|
||||
fps=fps,
|
||||
)
|
||||
return _ensure_uint8_rgb(result)
|
||||
except TypeError:
|
||||
# Try positional
|
||||
try:
|
||||
result = method(frames, audio, sample_rate, fps)
|
||||
return _ensure_uint8_rgb(result)
|
||||
except TypeError:
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
"MuseTalk wrapper could not find a working inference method. "
|
||||
"Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
|
||||
)
|
||||
|
||||
|
||||
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
|
||||
"""Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.
|
||||
|
||||
Repeats with a ping-pong / boomerang tail so the seam between loops is
|
||||
less jarring than a hard cut back to frame 0.
|
||||
"""
|
||||
if target_t <= 0:
|
||||
return frames
|
||||
t = len(frames)
|
||||
if t == target_t:
|
||||
return frames
|
||||
if t > target_t:
|
||||
return frames[:target_t]
|
||||
# Extend via ping-pong looping
|
||||
extended = [frames]
|
||||
total = t
|
||||
flip = True
|
||||
while total < target_t:
|
||||
seg = frames[::-1] if flip else frames
|
||||
extended.append(seg)
|
||||
total += t
|
||||
flip = not flip
|
||||
return np.concatenate(extended, axis=0)[:target_t]
|
||||
|
||||
|
||||
def _ensure_uint8_rgb(arr) -> np.ndarray:
|
||||
"""Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB."""
|
||||
result = np.asarray(arr)
|
||||
if result.dtype != np.uint8:
|
||||
if result.dtype in (np.float32, np.float64):
|
||||
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
|
||||
else:
|
||||
result = result.astype(np.uint8)
|
||||
if result.ndim == 3:
|
||||
result = result[None, ...]
|
||||
return result
|
||||
@@ -0,0 +1,146 @@
|
||||
"""ffmpeg-based frame + audio → MP4 muxing.
|
||||
|
||||
Uses the system ``ffmpeg`` binary already installed in the Dockerfile.
|
||||
No extra python dependencies beyond ``numpy``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ffmpeg_bin() -> str:
|
||||
bin_path = shutil.which("ffmpeg")
|
||||
if bin_path is None:
|
||||
raise RuntimeError(
|
||||
"ffmpeg binary not found on PATH. It should be installed by "
|
||||
"the Dockerfile (line 13). Ensure you're running inside the "
|
||||
"docker image or install ffmpeg locally."
|
||||
)
|
||||
return bin_path
|
||||
|
||||
|
||||
def _write_raw_frames(frames: np.ndarray, path: str) -> tuple[int, int]:
|
||||
"""Write uint8 RGB frames to ``path`` as raw rgb24 bytes. Returns (h, w)."""
|
||||
if frames.ndim != 4 or frames.shape[-1] != 3:
|
||||
raise ValueError(
|
||||
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
|
||||
)
|
||||
if frames.dtype != np.uint8:
|
||||
frames = frames.astype(np.uint8)
|
||||
with open(path, "wb") as f:
|
||||
f.write(frames.tobytes())
|
||||
_, h, w, _ = frames.shape
|
||||
return h, w
|
||||
|
||||
|
||||
def _write_wav(audio: np.ndarray, sample_rate: int, path: str) -> None:
|
||||
"""Write a float32 mono audio array to a 16-bit PCM WAV at ``path``."""
|
||||
from scipy.io import wavfile # type: ignore[import-not-found]
|
||||
audio = np.asarray(audio, dtype=np.float32).reshape(-1)
|
||||
int16 = np.clip(audio * 32767.0, -32768, 32767).astype(np.int16)
|
||||
wavfile.write(path, sample_rate, int16)
|
||||
|
||||
|
||||
def frames_to_mp4_loop(frames: np.ndarray, fps: int) -> bytes:
|
||||
"""Encode ``frames`` to a silent MP4 suitable for looping playback.
|
||||
|
||||
Used for the idle clip: no audio track, loopable on an HTMLMediaElement
|
||||
without audible seams.
|
||||
"""
|
||||
if frames.size == 0:
|
||||
raise ValueError("frames_to_mp4_loop: empty frames")
|
||||
|
||||
ffmpeg = _ffmpeg_bin()
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
raw_path = os.path.join(td, "frames.raw")
|
||||
out_path = os.path.join(td, "out.mp4")
|
||||
h, w = _write_raw_frames(frames, raw_path)
|
||||
|
||||
cmd = [
|
||||
ffmpeg, "-y",
|
||||
"-f", "rawvideo",
|
||||
"-pix_fmt", "rgb24",
|
||||
"-s", f"{w}x{h}",
|
||||
"-r", str(fps),
|
||||
"-i", raw_path,
|
||||
"-an",
|
||||
"-c:v", "libx264",
|
||||
"-preset", "veryfast",
|
||||
"-pix_fmt", "yuv420p",
|
||||
"-movflags", "+faststart",
|
||||
out_path,
|
||||
]
|
||||
log.debug("muxer idle clip: %s", " ".join(cmd))
|
||||
_run_ffmpeg(cmd)
|
||||
with open(out_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def frames_and_audio_to_mp4(
|
||||
frames: np.ndarray,
|
||||
audio: np.ndarray,
|
||||
sample_rate: int,
|
||||
fps: int,
|
||||
) -> bytes:
|
||||
"""Encode ``frames`` + ``audio`` to an MP4 with H.264 video + AAC audio.
|
||||
|
||||
Used for per-turn speaking clips.
|
||||
"""
|
||||
if frames.size == 0:
|
||||
raise ValueError("frames_and_audio_to_mp4: empty frames")
|
||||
if audio.size == 0:
|
||||
raise ValueError("frames_and_audio_to_mp4: empty audio")
|
||||
|
||||
ffmpeg = _ffmpeg_bin()
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
raw_path = os.path.join(td, "frames.raw")
|
||||
wav_path = os.path.join(td, "audio.wav")
|
||||
out_path = os.path.join(td, "out.mp4")
|
||||
|
||||
h, w = _write_raw_frames(frames, raw_path)
|
||||
_write_wav(audio, sample_rate, wav_path)
|
||||
|
||||
cmd = [
|
||||
ffmpeg, "-y",
|
||||
"-f", "rawvideo",
|
||||
"-pix_fmt", "rgb24",
|
||||
"-s", f"{w}x{h}",
|
||||
"-r", str(fps),
|
||||
"-i", raw_path,
|
||||
"-i", wav_path,
|
||||
"-c:v", "libx264",
|
||||
"-preset", "veryfast",
|
||||
"-pix_fmt", "yuv420p",
|
||||
"-c:a", "aac",
|
||||
"-b:a", "128k",
|
||||
"-shortest",
|
||||
"-movflags", "+faststart",
|
||||
out_path,
|
||||
]
|
||||
log.debug("muxer speaking clip: %s", " ".join(cmd))
|
||||
_run_ffmpeg(cmd)
|
||||
with open(out_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _run_ffmpeg(cmd: list[str]) -> None:
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.error("ffmpeg failed (exit %d): %s", e.returncode, e.stderr.decode(errors="replace"))
|
||||
raise
|
||||
if proc.returncode != 0: # pragma: no cover
|
||||
raise RuntimeError(f"ffmpeg returned {proc.returncode}")
|
||||
@@ -0,0 +1,423 @@
|
||||
"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V.
|
||||
|
||||
This wrapper targets LightX2V's actual Python entry points (verified against
|
||||
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
|
||||
|
||||
from lightx2v.utils.set_config import set_config
|
||||
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
|
||||
from lightx2v.infer import init_runner
|
||||
|
||||
args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...)
|
||||
config = set_config(args)
|
||||
input_info = init_empty_input_info(args.task, args.support_tasks)
|
||||
runner = init_runner(config) # loads all weights — done ONCE
|
||||
|
||||
update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...})
|
||||
runner.run_pipeline(input_info) # per-turn; MP4 written to save_result_path
|
||||
# LoRA hot-swap:
|
||||
runner.switch_lora(lora_path, strength) # swap in
|
||||
runner.switch_lora("", 0.0) # remove
|
||||
|
||||
Model weights are loaded once at construction and held resident across turns
|
||||
so reflective mode doesn't re-pay the load cost each reply.
|
||||
|
||||
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
|
||||
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards under high_noise_model/
|
||||
and low_noise_model/ are SKIPPED via
|
||||
ignore_patterns — we replace them with fp8.
|
||||
- lightx2v/Wan2.2-Distill-Models — exactly two safetensors files:
|
||||
the fp8 e4m3 4-step distilled high/low
|
||||
noise DIT checkpoints (~15 GB each).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from server.video import LoRASpec
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
|
||||
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8
|
||||
# files from the distill repo replace the DIT weights entirely. We must keep
|
||||
# the config.json / index.json metadata under high_noise_model/ and
|
||||
# low_noise_model/ (LightX2V's set_config reads architecture params like
|
||||
# ``dim`` from them) and the tokenizer files under google/.
|
||||
BASE_REPO_IGNORE_PATTERNS = [
|
||||
"high_noise_model/*.safetensors",
|
||||
"low_noise_model/*.safetensors",
|
||||
"assets/*",
|
||||
"examples/*",
|
||||
"nohup.out",
|
||||
"*.md",
|
||||
]
|
||||
|
||||
|
||||
class Wan22Pipeline:
|
||||
"""Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights.
|
||||
|
||||
Constructor downloads (if needed) both HF repos, writes a runtime JSON
|
||||
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
|
||||
``generate_i2v`` runs one inference turn against the already-loaded runner.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_repo: str,
|
||||
fp8_repo: str,
|
||||
config_json: str,
|
||||
model_cls: str = "wan2.2_moe_distill",
|
||||
resolution: int = 480,
|
||||
fps: int = 16,
|
||||
):
|
||||
self.base_repo = base_repo
|
||||
self.fp8_repo = fp8_repo
|
||||
self.config_json_template = config_json
|
||||
self.model_cls = model_cls
|
||||
self.resolution = resolution
|
||||
self.fps = fps
|
||||
self._applied_loras: list[LoRASpec] = []
|
||||
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and fp8 DIT ckpts.
|
||||
self._model_root = self._ensure_base_repo(base_repo)
|
||||
self._fp8_high, self._fp8_low = self._ensure_fp8_checkpoints(fp8_repo)
|
||||
|
||||
# 2. Materialize a runtime JSON config with absolute ckpt paths.
|
||||
self._runtime_json_path = self._build_runtime_config()
|
||||
|
||||
# 3. Build the argparse-like namespace LightX2V.set_config() expects.
|
||||
args = self._build_args(
|
||||
model_cls=model_cls,
|
||||
model_path=self._model_root,
|
||||
config_json=self._runtime_json_path,
|
||||
)
|
||||
|
||||
# 4. set_config → init_runner. Runner construction triggers weight load.
|
||||
# Imports are scoped here so ``import server.video_models.wan22``
|
||||
# never pulls in lightx2v (tests can import this module on CPU).
|
||||
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
|
||||
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
|
||||
from lightx2v.infer import init_runner # type: ignore[import-not-found]
|
||||
|
||||
log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
|
||||
model_cls, self._model_root)
|
||||
self._config = set_config(args)
|
||||
|
||||
self._input_info_template = init_empty_input_info(
|
||||
args.task, args.support_tasks
|
||||
)
|
||||
|
||||
log.info("LightX2V init_runner — loading weights (this takes a while)...")
|
||||
self._runner = init_runner(self._config)
|
||||
log.info("LightX2V runner loaded; weights resident.")
|
||||
|
||||
# --- Weight provisioning -------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _ensure_base_repo(base_repo: str) -> str:
|
||||
"""Return a local directory containing the Wan2.2 base support files.
|
||||
|
||||
If ``base_repo`` is already a local directory, use it as-is. Otherwise
|
||||
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
|
||||
shards (they're replaced by the fp8 files).
|
||||
"""
|
||||
if os.path.isdir(base_repo):
|
||||
return base_repo
|
||||
from huggingface_hub import snapshot_download
|
||||
log.info("Downloading Wan2.2 base support files from %s "
|
||||
"(skipping bf16 DIT shards)...", base_repo)
|
||||
return snapshot_download(
|
||||
repo_id=base_repo,
|
||||
ignore_patterns=BASE_REPO_IGNORE_PATTERNS,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]:
|
||||
"""Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair.
|
||||
|
||||
- If ``fp8_repo`` is a local directory, expect both files inside it.
|
||||
- Otherwise treat it as a HF repo id and download only the two files
|
||||
we need (not the ~150 GB of other variants in that repo).
|
||||
"""
|
||||
if not fp8_repo:
|
||||
raise ValueError("fp8_repo must be a HF repo id or local directory.")
|
||||
if os.path.isdir(fp8_repo):
|
||||
high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE)
|
||||
low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE)
|
||||
if not (os.path.isfile(high) and os.path.isfile(low)):
|
||||
raise FileNotFoundError(
|
||||
f"fp8 checkpoints not found in {fp8_repo}: expected "
|
||||
f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}"
|
||||
)
|
||||
return high, low
|
||||
from huggingface_hub import hf_hub_download
|
||||
log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo)
|
||||
high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE)
|
||||
low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE)
|
||||
return high, low
|
||||
|
||||
def _build_runtime_config(self) -> str:
|
||||
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
|
||||
with open(self.config_json_template, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
# Drop editorial comments before passing to LightX2V.
|
||||
cfg.pop("_comment", None)
|
||||
cfg["high_noise_quantized_ckpt"] = self._fp8_high
|
||||
cfg["low_noise_quantized_ckpt"] = self._fp8_low
|
||||
cfg.setdefault("fps", self.fps)
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(
|
||||
prefix="wan22_fp8_", suffix=".json",
|
||||
mode="w", delete=False, encoding="utf-8",
|
||||
)
|
||||
json.dump(cfg, tmp, indent=2)
|
||||
tmp.close()
|
||||
log.info("Runtime LightX2V config: %s", tmp.name)
|
||||
return tmp.name
|
||||
|
||||
@staticmethod
|
||||
def _build_args(
|
||||
*, model_cls: str, model_path: str, config_json: str
|
||||
) -> argparse.Namespace:
|
||||
"""Mirror every field from ``lightx2v.infer.main``'s argparse so
|
||||
``set_config`` finds the attributes it expects. We only customize the
|
||||
model/task/path fields; everything else stays at the CLI defaults.
|
||||
"""
|
||||
return argparse.Namespace(
|
||||
seed=42,
|
||||
model_cls=model_cls,
|
||||
task="i2v",
|
||||
support_tasks=[],
|
||||
model_path=model_path,
|
||||
sf_model_path=None,
|
||||
config_json=config_json,
|
||||
use_prompt_enhancer=False,
|
||||
prompt="",
|
||||
negative_prompt="",
|
||||
image_path="",
|
||||
last_frame_path="",
|
||||
audio_path="",
|
||||
image_strength="1.0",
|
||||
image_frame_idx="",
|
||||
src_ref_images=None,
|
||||
src_video=None,
|
||||
src_mask=None,
|
||||
src_pose_path=None,
|
||||
src_face_path=None,
|
||||
src_bg_path=None,
|
||||
src_mask_path=None,
|
||||
pose=None,
|
||||
action_path=None,
|
||||
action_ckpt=None,
|
||||
save_result_path=None,
|
||||
return_result_tensor=False,
|
||||
target_shape=[],
|
||||
target_video_length=81,
|
||||
aspect_ratio="",
|
||||
video_path=None,
|
||||
sr_ratio=2.0,
|
||||
)
|
||||
|
||||
# --- LoRA --------------------------------------------------------------
|
||||
|
||||
def load_loras(self, specs: list["LoRASpec"]) -> None:
|
||||
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
|
||||
|
||||
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
|
||||
to route the LoRA to the correct expert.
|
||||
|
||||
With ``lazy_load`` the DIT models are ``None`` at this point, so
|
||||
runtime ``switch_lora`` is impossible. Instead we inject
|
||||
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
|
||||
the LoRAs are applied when the models materialise on first inference.
|
||||
|
||||
Without ``lazy_load`` (models already resident) we call
|
||||
``switch_lora`` with explicit high/low keyword args.
|
||||
"""
|
||||
if not specs:
|
||||
return
|
||||
|
||||
# Resolve every path up-front (may trigger HF download).
|
||||
resolved: list[tuple["LoRASpec", str]] = []
|
||||
for spec in specs:
|
||||
local_path = self._resolve_lora_path(spec.path)
|
||||
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
|
||||
spec.name or spec.path, spec.weight, spec.target,
|
||||
local_path)
|
||||
resolved.append((spec, local_path))
|
||||
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
# Build the lora_configs list that LightX2V's lazy-load path
|
||||
# reads inside MultiDistillModelStruct.infer().
|
||||
lora_cfgs = []
|
||||
for spec, local_path in resolved:
|
||||
# LightX2V expects name "high_noise_model" / "low_noise_model"
|
||||
cfg_name = {
|
||||
"high_noise": "high_noise_model",
|
||||
"low_noise": "low_noise_model",
|
||||
}.get(spec.target)
|
||||
if cfg_name is None:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
lora_cfgs.append({
|
||||
"name": cfg_name,
|
||||
"path": local_path,
|
||||
"strength": spec.weight,
|
||||
})
|
||||
self._runner.set_config({
|
||||
"lora_configs": lora_cfgs,
|
||||
"lora_dynamic_apply": True,
|
||||
})
|
||||
else:
|
||||
# Models are loaded — use runtime hot-swap.
|
||||
high_path = high_strength = None
|
||||
low_path = low_strength = None
|
||||
for spec, local_path in resolved:
|
||||
if spec.target == "high_noise":
|
||||
high_path, high_strength = local_path, spec.weight
|
||||
elif spec.target == "low_noise":
|
||||
low_path, low_strength = local_path, spec.weight
|
||||
else:
|
||||
raise ValueError(
|
||||
f"LoRA target must be 'high_noise' or 'low_noise', "
|
||||
f"got {spec.target!r}")
|
||||
|
||||
kwargs: dict = {}
|
||||
if high_path is not None:
|
||||
kwargs["high_lora_path"] = high_path
|
||||
kwargs["high_lora_strength"] = high_strength
|
||||
if low_path is not None:
|
||||
kwargs["low_lora_path"] = low_path
|
||||
kwargs["low_lora_strength"] = low_strength
|
||||
ok = self._runner.switch_lora(**kwargs)
|
||||
if not ok:
|
||||
raise RuntimeError(
|
||||
"runner.switch_lora returned False. Check that your "
|
||||
"LightX2V build supports runtime LoRA updates for "
|
||||
f"{self.model_cls}.")
|
||||
|
||||
self._applied_loras = list(specs)
|
||||
|
||||
def unload_loras(self) -> None:
|
||||
"""Remove all currently applied LoRAs."""
|
||||
if not self._applied_loras:
|
||||
return
|
||||
lazy = self._config.get("lazy_load", False)
|
||||
if lazy:
|
||||
self._runner.set_config({
|
||||
"lora_configs": None,
|
||||
"lora_dynamic_apply": False,
|
||||
})
|
||||
# If models were materialised, drop them so the next inference
|
||||
# recreates them without LoRAs.
|
||||
model_struct = getattr(self._runner, "model", None)
|
||||
if model_struct is not None and hasattr(model_struct, "model"):
|
||||
for i in range(len(model_struct.model)):
|
||||
model_struct.model[i] = None
|
||||
else:
|
||||
self._runner.switch_lora("", 0.0)
|
||||
self._applied_loras = []
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_path(path: str) -> str:
|
||||
"""Resolve a LoRA path. Supports:
|
||||
- Absolute/relative local paths (returned as-is if the file exists)
|
||||
- ``repo_id:filename`` HuggingFace references
|
||||
"""
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
if ":" in path and not path.startswith(("/", "./")):
|
||||
repo_id, filename = path.split(":", 1)
|
||||
from huggingface_hub import hf_hub_download
|
||||
return hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
return path
|
||||
|
||||
# --- Inference ---------------------------------------------------------
|
||||
|
||||
def generate_i2v(
|
||||
self,
|
||||
image_path: str,
|
||||
prompt: str,
|
||||
seconds: int,
|
||||
seed: int | None = None,
|
||||
negative_prompt: str = "",
|
||||
) -> np.ndarray:
|
||||
"""Run image-to-video inference and return decoded frames.
|
||||
|
||||
Returns ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB.
|
||||
"""
|
||||
if seed is None:
|
||||
seed = random.randint(0, 2**31 - 1)
|
||||
|
||||
# Wan2.2 target_video_length is "frames including the conditioning
|
||||
# frame", so N seconds → N*fps + 1.
|
||||
target_frames = seconds * self.fps + 1
|
||||
|
||||
from lightx2v.utils.input_info import update_input_info_from_dict # type: ignore[import-not-found]
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
|
||||
out_path = tf.name
|
||||
try:
|
||||
log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d → %s",
|
||||
prompt[:80], seconds, seed, out_path)
|
||||
update_input_info_from_dict(
|
||||
self._input_info_template,
|
||||
{
|
||||
"seed": seed,
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"image_path": image_path,
|
||||
"save_result_path": out_path,
|
||||
"target_video_length": target_frames,
|
||||
"return_result_tensor": False,
|
||||
},
|
||||
)
|
||||
self._runner.run_pipeline(self._input_info_template)
|
||||
return _read_mp4_to_frames(out_path)
|
||||
finally:
|
||||
try:
|
||||
os.remove(out_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
# --- MP4 decoding helper ------------------------------------------------------
|
||||
|
||||
def _read_mp4_to_frames(path: str) -> np.ndarray:
|
||||
"""Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``."""
|
||||
try:
|
||||
import imageio.v3 as iio # type: ignore[import-not-found]
|
||||
frames = iio.imread(path, plugin="pyav")
|
||||
arr = np.asarray(frames)
|
||||
if arr.ndim == 3:
|
||||
arr = arr[None, ...]
|
||||
return arr.astype(np.uint8)
|
||||
except Exception as e: # pragma: no cover - fallback path
|
||||
log.warning("imageio decode failed (%s); falling back to cv2", e)
|
||||
import cv2 # type: ignore[import-not-found]
|
||||
cap = cv2.VideoCapture(path)
|
||||
frames: list[np.ndarray] = []
|
||||
while True:
|
||||
ok, frame = cap.read()
|
||||
if not ok:
|
||||
break
|
||||
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
cap.release()
|
||||
if not frames:
|
||||
raise RuntimeError(f"Failed to decode any frames from {path}")
|
||||
return np.stack(frames, axis=0).astype(np.uint8)
|
||||
Reference in New Issue
Block a user