first stab at adding video

This commit is contained in:
2026-04-12 04:11:52 -04:00
parent 680c5b04cc
commit 2818b41004
37 changed files with 2982 additions and 24 deletions
+164
View File
@@ -0,0 +1,164 @@
"""MuseTalk audio-driven lip-sync wrapper.
MuseTalk takes a sequence of face frames + driving audio and returns a new
sequence of frames where the mouth region is animated to match the audio.
This module isolates MuseTalk's real API behind a single ``lip_sync()``
method. MuseTalk's upstream Python surface varies between forks — if the
real import path or call signature differs, update this file only.
"""
from __future__ import annotations
import logging
import os
import numpy as np
log = logging.getLogger(__name__)
class MuseTalkEngine:
"""Thin wrapper over MuseTalk inference."""
def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
self.model_path = model_path
# MuseTalk's canonical entry point is ``musetalk.inference`` or a
# similar ``MuseTalkInfer`` class. Try the most common imports.
self._infer = self._load_impl(model_path)
log.info("MuseTalk engine loaded from %s", model_path)
@staticmethod
def _load_impl(model_path: str):
"""Load the MuseTalk inference implementation.
If none of the known entry points work the error message points at
this file so you know where to fix it.
"""
resolved = model_path
if not os.path.isdir(model_path) and "/" in model_path:
try:
from huggingface_hub import snapshot_download
resolved = snapshot_download(repo_id=model_path)
except Exception as e: # pragma: no cover
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
# Try upstream MuseTalk repo layout.
try:
from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found]
return MuseTalkInference(model_path=resolved)
except ImportError:
pass
try:
from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found]
return MuseTalkInfer(model_path=resolved)
except ImportError:
pass
try:
from musetalk import Inference # type: ignore[import-not-found]
return Inference(model_path=resolved)
except ImportError:
pass
raise RuntimeError(
"MuseTalk is installed but no known Python entry point was found. "
"Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
"to match the installed MuseTalk version."
)
# --- Inference ---------------------------------------------------------
def lip_sync(
self,
frames: np.ndarray,
audio: np.ndarray,
sample_rate: int,
fps: int,
) -> np.ndarray:
"""Return new frames with lip-sync applied to match ``audio``.
Args:
frames: uint8 ``[T, H, W, 3]`` RGB base frames.
audio: float32 mono 1D audio.
sample_rate: sample rate of ``audio``.
fps: frame rate of ``frames``.
Returns:
uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
to match audio duration at ``fps``.
"""
if frames.ndim != 4 or frames.shape[-1] != 3:
raise ValueError(
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
)
# Normalise frame count to audio duration so the caller doesn't have
# to do the arithmetic.
target_t = int(round(len(audio) / sample_rate * fps))
if target_t > 0 and len(frames) != target_t:
frames = _fit_frames_to_length(frames, target_t)
# The real MuseTalk call signature varies. Most common is a method
# like ``run(frames, audio, sr, fps)`` or ``infer(...)``.
for method_name in ("run", "infer", "lip_sync", "__call__"):
method = getattr(self._infer, method_name, None)
if method is None:
continue
try:
result = method(
frames=frames,
audio=audio,
sample_rate=sample_rate,
fps=fps,
)
return _ensure_uint8_rgb(result)
except TypeError:
# Try positional
try:
result = method(frames, audio, sample_rate, fps)
return _ensure_uint8_rgb(result)
except TypeError:
continue
raise RuntimeError(
"MuseTalk wrapper could not find a working inference method. "
"Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
)
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
"""Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.
Repeats with a ping-pong / boomerang tail so the seam between loops is
less jarring than a hard cut back to frame 0.
"""
if target_t <= 0:
return frames
t = len(frames)
if t == target_t:
return frames
if t > target_t:
return frames[:target_t]
# Extend via ping-pong looping
extended = [frames]
total = t
flip = True
while total < target_t:
seg = frames[::-1] if flip else frames
extended.append(seg)
total += t
flip = not flip
return np.concatenate(extended, axis=0)[:target_t]
def _ensure_uint8_rgb(arr) -> np.ndarray:
"""Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB."""
result = np.asarray(arr)
if result.dtype != np.uint8:
if result.dtype in (np.float32, np.float64):
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
else:
result = result.astype(np.uint8)
if result.ndim == 3:
result = result[None, ...]
return result