first stab at adding video

This commit is contained in:
2026-04-12 04:11:52 -04:00
parent 680c5b04cc
commit 2818b41004
37 changed files with 2982 additions and 24 deletions
+10
View File
@@ -0,0 +1,10 @@
"""Thin wrappers around 3rd-party video generation models.
Each submodule isolates one external dependency so the real API surface
can be updated in a single file without touching the pipeline.
Submodules:
- ``wan22``: Wan2.2-Lightning image-to-video via LightX2V
- ``musetalk``: MuseTalk audio-driven lip-sync
- ``muxer``: ffmpeg-based frame/audio → MP4 encoding
"""
+164
View File
@@ -0,0 +1,164 @@
"""MuseTalk audio-driven lip-sync wrapper.
MuseTalk takes a sequence of face frames + driving audio and returns a new
sequence of frames where the mouth region is animated to match the audio.
This module isolates MuseTalk's real API behind a single ``lip_sync()``
method. MuseTalk's upstream Python surface varies between forks — if the
real import path or call signature differs, update this file only.
"""
from __future__ import annotations
import logging
import os
import numpy as np
log = logging.getLogger(__name__)
class MuseTalkEngine:
"""Thin wrapper over MuseTalk inference."""
def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
self.model_path = model_path
# MuseTalk's canonical entry point is ``musetalk.inference`` or a
# similar ``MuseTalkInfer`` class. Try the most common imports.
self._infer = self._load_impl(model_path)
log.info("MuseTalk engine loaded from %s", model_path)
@staticmethod
def _load_impl(model_path: str):
"""Load the MuseTalk inference implementation.
If none of the known entry points work the error message points at
this file so you know where to fix it.
"""
resolved = model_path
if not os.path.isdir(model_path) and "/" in model_path:
try:
from huggingface_hub import snapshot_download
resolved = snapshot_download(repo_id=model_path)
except Exception as e: # pragma: no cover
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
# Try upstream MuseTalk repo layout.
try:
from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found]
return MuseTalkInference(model_path=resolved)
except ImportError:
pass
try:
from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found]
return MuseTalkInfer(model_path=resolved)
except ImportError:
pass
try:
from musetalk import Inference # type: ignore[import-not-found]
return Inference(model_path=resolved)
except ImportError:
pass
raise RuntimeError(
"MuseTalk is installed but no known Python entry point was found. "
"Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
"to match the installed MuseTalk version."
)
# --- Inference ---------------------------------------------------------
def lip_sync(
self,
frames: np.ndarray,
audio: np.ndarray,
sample_rate: int,
fps: int,
) -> np.ndarray:
"""Return new frames with lip-sync applied to match ``audio``.
Args:
frames: uint8 ``[T, H, W, 3]`` RGB base frames.
audio: float32 mono 1D audio.
sample_rate: sample rate of ``audio``.
fps: frame rate of ``frames``.
Returns:
uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
to match audio duration at ``fps``.
"""
if frames.ndim != 4 or frames.shape[-1] != 3:
raise ValueError(
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
)
# Normalise frame count to audio duration so the caller doesn't have
# to do the arithmetic.
target_t = int(round(len(audio) / sample_rate * fps))
if target_t > 0 and len(frames) != target_t:
frames = _fit_frames_to_length(frames, target_t)
# The real MuseTalk call signature varies. Most common is a method
# like ``run(frames, audio, sr, fps)`` or ``infer(...)``.
for method_name in ("run", "infer", "lip_sync", "__call__"):
method = getattr(self._infer, method_name, None)
if method is None:
continue
try:
result = method(
frames=frames,
audio=audio,
sample_rate=sample_rate,
fps=fps,
)
return _ensure_uint8_rgb(result)
except TypeError:
# Try positional
try:
result = method(frames, audio, sample_rate, fps)
return _ensure_uint8_rgb(result)
except TypeError:
continue
raise RuntimeError(
"MuseTalk wrapper could not find a working inference method. "
"Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
)
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
"""Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.
Repeats with a ping-pong / boomerang tail so the seam between loops is
less jarring than a hard cut back to frame 0.
"""
if target_t <= 0:
return frames
t = len(frames)
if t == target_t:
return frames
if t > target_t:
return frames[:target_t]
# Extend via ping-pong looping
extended = [frames]
total = t
flip = True
while total < target_t:
seg = frames[::-1] if flip else frames
extended.append(seg)
total += t
flip = not flip
return np.concatenate(extended, axis=0)[:target_t]
def _ensure_uint8_rgb(arr) -> np.ndarray:
"""Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB."""
result = np.asarray(arr)
if result.dtype != np.uint8:
if result.dtype in (np.float32, np.float64):
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
else:
result = result.astype(np.uint8)
if result.ndim == 3:
result = result[None, ...]
return result
+146
View File
@@ -0,0 +1,146 @@
"""ffmpeg-based frame + audio → MP4 muxing.
Uses the system ``ffmpeg`` binary already installed in the Dockerfile.
No extra python dependencies beyond ``numpy``.
"""
from __future__ import annotations
import logging
import os
import shutil
import subprocess
import tempfile
import numpy as np
log = logging.getLogger(__name__)
def _ffmpeg_bin() -> str:
bin_path = shutil.which("ffmpeg")
if bin_path is None:
raise RuntimeError(
"ffmpeg binary not found on PATH. It should be installed by "
"the Dockerfile (line 13). Ensure you're running inside the "
"docker image or install ffmpeg locally."
)
return bin_path
def _write_raw_frames(frames: np.ndarray, path: str) -> tuple[int, int]:
"""Write uint8 RGB frames to ``path`` as raw rgb24 bytes. Returns (h, w)."""
if frames.ndim != 4 or frames.shape[-1] != 3:
raise ValueError(
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
)
if frames.dtype != np.uint8:
frames = frames.astype(np.uint8)
with open(path, "wb") as f:
f.write(frames.tobytes())
_, h, w, _ = frames.shape
return h, w
def _write_wav(audio: np.ndarray, sample_rate: int, path: str) -> None:
"""Write a float32 mono audio array to a 16-bit PCM WAV at ``path``."""
from scipy.io import wavfile # type: ignore[import-not-found]
audio = np.asarray(audio, dtype=np.float32).reshape(-1)
int16 = np.clip(audio * 32767.0, -32768, 32767).astype(np.int16)
wavfile.write(path, sample_rate, int16)
def frames_to_mp4_loop(frames: np.ndarray, fps: int) -> bytes:
"""Encode ``frames`` to a silent MP4 suitable for looping playback.
Used for the idle clip: no audio track, loopable on an HTMLMediaElement
without audible seams.
"""
if frames.size == 0:
raise ValueError("frames_to_mp4_loop: empty frames")
ffmpeg = _ffmpeg_bin()
with tempfile.TemporaryDirectory() as td:
raw_path = os.path.join(td, "frames.raw")
out_path = os.path.join(td, "out.mp4")
h, w = _write_raw_frames(frames, raw_path)
cmd = [
ffmpeg, "-y",
"-f", "rawvideo",
"-pix_fmt", "rgb24",
"-s", f"{w}x{h}",
"-r", str(fps),
"-i", raw_path,
"-an",
"-c:v", "libx264",
"-preset", "veryfast",
"-pix_fmt", "yuv420p",
"-movflags", "+faststart",
out_path,
]
log.debug("muxer idle clip: %s", " ".join(cmd))
_run_ffmpeg(cmd)
with open(out_path, "rb") as f:
return f.read()
def frames_and_audio_to_mp4(
frames: np.ndarray,
audio: np.ndarray,
sample_rate: int,
fps: int,
) -> bytes:
"""Encode ``frames`` + ``audio`` to an MP4 with H.264 video + AAC audio.
Used for per-turn speaking clips.
"""
if frames.size == 0:
raise ValueError("frames_and_audio_to_mp4: empty frames")
if audio.size == 0:
raise ValueError("frames_and_audio_to_mp4: empty audio")
ffmpeg = _ffmpeg_bin()
with tempfile.TemporaryDirectory() as td:
raw_path = os.path.join(td, "frames.raw")
wav_path = os.path.join(td, "audio.wav")
out_path = os.path.join(td, "out.mp4")
h, w = _write_raw_frames(frames, raw_path)
_write_wav(audio, sample_rate, wav_path)
cmd = [
ffmpeg, "-y",
"-f", "rawvideo",
"-pix_fmt", "rgb24",
"-s", f"{w}x{h}",
"-r", str(fps),
"-i", raw_path,
"-i", wav_path,
"-c:v", "libx264",
"-preset", "veryfast",
"-pix_fmt", "yuv420p",
"-c:a", "aac",
"-b:a", "128k",
"-shortest",
"-movflags", "+faststart",
out_path,
]
log.debug("muxer speaking clip: %s", " ".join(cmd))
_run_ffmpeg(cmd)
with open(out_path, "rb") as f:
return f.read()
def _run_ffmpeg(cmd: list[str]) -> None:
try:
proc = subprocess.run(
cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError as e:
log.error("ffmpeg failed (exit %d): %s", e.returncode, e.stderr.decode(errors="replace"))
raise
if proc.returncode != 0: # pragma: no cover
raise RuntimeError(f"ffmpeg returned {proc.returncode}")
+423
View File
@@ -0,0 +1,423 @@
"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V.
This wrapper targets LightX2V's actual Python entry points (verified against
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
from lightx2v.utils.set_config import set_config
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.infer import init_runner
args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...)
config = set_config(args)
input_info = init_empty_input_info(args.task, args.support_tasks)
runner = init_runner(config) # loads all weights — done ONCE
update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...})
runner.run_pipeline(input_info) # per-turn; MP4 written to save_result_path
# LoRA hot-swap:
runner.switch_lora(lora_path, strength) # swap in
runner.switch_lora("", 0.0) # remove
Model weights are loaded once at construction and held resident across turns
so reflective mode doesn't re-pay the load cost each reply.
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
The bf16 DIT shards under high_noise_model/
and low_noise_model/ are SKIPPED via
ignore_patterns — we replace them with fp8.
- lightx2v/Wan2.2-Distill-Models — exactly two safetensors files:
the fp8 e4m3 4-step distilled high/low
noise DIT checkpoints (~15 GB each).
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import random
import tempfile
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from server.video import LoRASpec
log = logging.getLogger(__name__)
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8
# files from the distill repo replace the DIT weights entirely. We must keep
# the config.json / index.json metadata under high_noise_model/ and
# low_noise_model/ (LightX2V's set_config reads architecture params like
# ``dim`` from them) and the tokenizer files under google/.
BASE_REPO_IGNORE_PATTERNS = [
"high_noise_model/*.safetensors",
"low_noise_model/*.safetensors",
"assets/*",
"examples/*",
"nohup.out",
"*.md",
]
class Wan22Pipeline:
"""Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights.
Constructor downloads (if needed) both HF repos, writes a runtime JSON
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
``generate_i2v`` runs one inference turn against the already-loaded runner.
"""
def __init__(
self,
base_repo: str,
fp8_repo: str,
config_json: str,
model_cls: str = "wan2.2_moe_distill",
resolution: int = 480,
fps: int = 16,
):
self.base_repo = base_repo
self.fp8_repo = fp8_repo
self.config_json_template = config_json
self.model_cls = model_cls
self.resolution = resolution
self.fps = fps
self._applied_loras: list[LoRASpec] = []
# 1. Resolve / download base repo (T5/VAE/config) and fp8 DIT ckpts.
self._model_root = self._ensure_base_repo(base_repo)
self._fp8_high, self._fp8_low = self._ensure_fp8_checkpoints(fp8_repo)
# 2. Materialize a runtime JSON config with absolute ckpt paths.
self._runtime_json_path = self._build_runtime_config()
# 3. Build the argparse-like namespace LightX2V.set_config() expects.
args = self._build_args(
model_cls=model_cls,
model_path=self._model_root,
config_json=self._runtime_json_path,
)
# 4. set_config → init_runner. Runner construction triggers weight load.
# Imports are scoped here so ``import server.video_models.wan22``
# never pulls in lightx2v (tests can import this module on CPU).
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
from lightx2v.infer import init_runner # type: ignore[import-not-found]
log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
model_cls, self._model_root)
self._config = set_config(args)
self._input_info_template = init_empty_input_info(
args.task, args.support_tasks
)
log.info("LightX2V init_runner — loading weights (this takes a while)...")
self._runner = init_runner(self._config)
log.info("LightX2V runner loaded; weights resident.")
# --- Weight provisioning -------------------------------------------------
@staticmethod
def _ensure_base_repo(base_repo: str) -> str:
"""Return a local directory containing the Wan2.2 base support files.
If ``base_repo`` is already a local directory, use it as-is. Otherwise
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
shards (they're replaced by the fp8 files).
"""
if os.path.isdir(base_repo):
return base_repo
from huggingface_hub import snapshot_download
log.info("Downloading Wan2.2 base support files from %s "
"(skipping bf16 DIT shards)...", base_repo)
return snapshot_download(
repo_id=base_repo,
ignore_patterns=BASE_REPO_IGNORE_PATTERNS,
)
@staticmethod
def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]:
"""Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair.
- If ``fp8_repo`` is a local directory, expect both files inside it.
- Otherwise treat it as a HF repo id and download only the two files
we need (not the ~150 GB of other variants in that repo).
"""
if not fp8_repo:
raise ValueError("fp8_repo must be a HF repo id or local directory.")
if os.path.isdir(fp8_repo):
high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE)
low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE)
if not (os.path.isfile(high) and os.path.isfile(low)):
raise FileNotFoundError(
f"fp8 checkpoints not found in {fp8_repo}: expected "
f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}"
)
return high, low
from huggingface_hub import hf_hub_download
log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo)
high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE)
low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE)
return high, low
def _build_runtime_config(self) -> str:
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
with open(self.config_json_template, "r", encoding="utf-8") as f:
cfg = json.load(f)
# Drop editorial comments before passing to LightX2V.
cfg.pop("_comment", None)
cfg["high_noise_quantized_ckpt"] = self._fp8_high
cfg["low_noise_quantized_ckpt"] = self._fp8_low
cfg.setdefault("fps", self.fps)
tmp = tempfile.NamedTemporaryFile(
prefix="wan22_fp8_", suffix=".json",
mode="w", delete=False, encoding="utf-8",
)
json.dump(cfg, tmp, indent=2)
tmp.close()
log.info("Runtime LightX2V config: %s", tmp.name)
return tmp.name
@staticmethod
def _build_args(
*, model_cls: str, model_path: str, config_json: str
) -> argparse.Namespace:
"""Mirror every field from ``lightx2v.infer.main``'s argparse so
``set_config`` finds the attributes it expects. We only customize the
model/task/path fields; everything else stays at the CLI defaults.
"""
return argparse.Namespace(
seed=42,
model_cls=model_cls,
task="i2v",
support_tasks=[],
model_path=model_path,
sf_model_path=None,
config_json=config_json,
use_prompt_enhancer=False,
prompt="",
negative_prompt="",
image_path="",
last_frame_path="",
audio_path="",
image_strength="1.0",
image_frame_idx="",
src_ref_images=None,
src_video=None,
src_mask=None,
src_pose_path=None,
src_face_path=None,
src_bg_path=None,
src_mask_path=None,
pose=None,
action_path=None,
action_ckpt=None,
save_result_path=None,
return_result_tensor=False,
target_shape=[],
target_video_length=81,
aspect_ratio="",
video_path=None,
sr_ratio=2.0,
)
# --- LoRA --------------------------------------------------------------
def load_loras(self, specs: list["LoRASpec"]) -> None:
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
to route the LoRA to the correct expert.
With ``lazy_load`` the DIT models are ``None`` at this point, so
runtime ``switch_lora`` is impossible. Instead we inject
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
the LoRAs are applied when the models materialise on first inference.
Without ``lazy_load`` (models already resident) we call
``switch_lora`` with explicit high/low keyword args.
"""
if not specs:
return
# Resolve every path up-front (may trigger HF download).
resolved: list[tuple["LoRASpec", str]] = []
for spec in specs:
local_path = self._resolve_lora_path(spec.path)
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
spec.name or spec.path, spec.weight, spec.target,
local_path)
resolved.append((spec, local_path))
lazy = self._config.get("lazy_load", False)
if lazy:
# Build the lora_configs list that LightX2V's lazy-load path
# reads inside MultiDistillModelStruct.infer().
lora_cfgs = []
for spec, local_path in resolved:
# LightX2V expects name "high_noise_model" / "low_noise_model"
cfg_name = {
"high_noise": "high_noise_model",
"low_noise": "low_noise_model",
}.get(spec.target)
if cfg_name is None:
raise ValueError(
f"LoRA target must be 'high_noise' or 'low_noise', "
f"got {spec.target!r}")
lora_cfgs.append({
"name": cfg_name,
"path": local_path,
"strength": spec.weight,
})
self._runner.set_config({
"lora_configs": lora_cfgs,
"lora_dynamic_apply": True,
})
else:
# Models are loaded — use runtime hot-swap.
high_path = high_strength = None
low_path = low_strength = None
for spec, local_path in resolved:
if spec.target == "high_noise":
high_path, high_strength = local_path, spec.weight
elif spec.target == "low_noise":
low_path, low_strength = local_path, spec.weight
else:
raise ValueError(
f"LoRA target must be 'high_noise' or 'low_noise', "
f"got {spec.target!r}")
kwargs: dict = {}
if high_path is not None:
kwargs["high_lora_path"] = high_path
kwargs["high_lora_strength"] = high_strength
if low_path is not None:
kwargs["low_lora_path"] = low_path
kwargs["low_lora_strength"] = low_strength
ok = self._runner.switch_lora(**kwargs)
if not ok:
raise RuntimeError(
"runner.switch_lora returned False. Check that your "
"LightX2V build supports runtime LoRA updates for "
f"{self.model_cls}.")
self._applied_loras = list(specs)
def unload_loras(self) -> None:
"""Remove all currently applied LoRAs."""
if not self._applied_loras:
return
lazy = self._config.get("lazy_load", False)
if lazy:
self._runner.set_config({
"lora_configs": None,
"lora_dynamic_apply": False,
})
# If models were materialised, drop them so the next inference
# recreates them without LoRAs.
model_struct = getattr(self._runner, "model", None)
if model_struct is not None and hasattr(model_struct, "model"):
for i in range(len(model_struct.model)):
model_struct.model[i] = None
else:
self._runner.switch_lora("", 0.0)
self._applied_loras = []
@staticmethod
def _resolve_lora_path(path: str) -> str:
"""Resolve a LoRA path. Supports:
- Absolute/relative local paths (returned as-is if the file exists)
- ``repo_id:filename`` HuggingFace references
"""
if os.path.isfile(path):
return path
if ":" in path and not path.startswith(("/", "./")):
repo_id, filename = path.split(":", 1)
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=repo_id, filename=filename)
return path
# --- Inference ---------------------------------------------------------
def generate_i2v(
self,
image_path: str,
prompt: str,
seconds: int,
seed: int | None = None,
negative_prompt: str = "",
) -> np.ndarray:
"""Run image-to-video inference and return decoded frames.
Returns ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB.
"""
if seed is None:
seed = random.randint(0, 2**31 - 1)
# Wan2.2 target_video_length is "frames including the conditioning
# frame", so N seconds → N*fps + 1.
target_frames = seconds * self.fps + 1
from lightx2v.utils.input_info import update_input_info_from_dict # type: ignore[import-not-found]
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
out_path = tf.name
try:
log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d%s",
prompt[:80], seconds, seed, out_path)
update_input_info_from_dict(
self._input_info_template,
{
"seed": seed,
"prompt": prompt,
"negative_prompt": negative_prompt,
"image_path": image_path,
"save_result_path": out_path,
"target_video_length": target_frames,
"return_result_tensor": False,
},
)
self._runner.run_pipeline(self._input_info_template)
return _read_mp4_to_frames(out_path)
finally:
try:
os.remove(out_path)
except OSError:
pass
# --- MP4 decoding helper ------------------------------------------------------
def _read_mp4_to_frames(path: str) -> np.ndarray:
"""Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``."""
try:
import imageio.v3 as iio # type: ignore[import-not-found]
frames = iio.imread(path, plugin="pyav")
arr = np.asarray(frames)
if arr.ndim == 3:
arr = arr[None, ...]
return arr.astype(np.uint8)
except Exception as e: # pragma: no cover - fallback path
log.warning("imageio decode failed (%s); falling back to cv2", e)
import cv2 # type: ignore[import-not-found]
cap = cv2.VideoCapture(path)
frames: list[np.ndarray] = []
while True:
ok, frame = cap.read()
if not ok:
break
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()
if not frames:
raise RuntimeError(f"Failed to decode any frames from {path}")
return np.stack(frames, axis=0).astype(np.uint8)