Files
2026-04-16 10:00:37 -04:00

152 lines
5.8 KiB
Python

"""MuseTalk audio-driven lip-sync wrapper.

MuseTalk takes a sequence of face frames plus driving audio and returns a new
sequence of frames in which the mouth region is animated to match the audio.
This module isolates MuseTalk's real API behind a single ``lip_sync()``
method. MuseTalk's upstream Python surface varies between forks — if the
real import path or call signature differs, update this file only.

Note: the full inference pipeline is not wired up yet; ``lip_sync()``
currently raises ``NotImplementedError`` after checking its inputs.
"""
from __future__ import annotations
import logging
import os
import numpy as np
log = logging.getLogger(__name__)
class MuseTalkEngine:
    """Thin wrapper over MuseTalk inference.

    Construction verifies that the MuseTalk Python package is importable and
    resolves where the weights live. The lip-sync pipeline itself is not
    wired up yet — see ``lip_sync``.
    """

    def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
        """Verify the MuseTalk install and resolve the weight location.

        Args:
            model_path: a local directory with MuseTalk weights, or a
                Hugging Face repo id (e.g. ``"TMElyralab/MuseTalk"``) that is
                snapshot-downloaded on a best-effort basis.

        Raises:
            RuntimeError: if the MuseTalk Python package is not importable.
        """
        self.model_path = model_path
        # _load_impl only checks the install and resolves the weight path;
        # no models are loaded here.
        self._infer = self._load_impl(model_path)
        log.info("MuseTalk engine loaded from %s", self._infer["model_path"])

    @staticmethod
    def _load_impl(model_path: str) -> dict:
        """Verify the MuseTalk install and resolve the weight directory.

        Upstream MuseTalk has no library-style entry point — it's a bundle
        of training/inference CLI scripts. The bhetherman/MuseTalk fork at
        ``third_party/MuseTalk`` adds package metadata, but the low-level
        API is still the raw ``musetalk.utils.*`` and ``musetalk.models.*``
        modules. We import a couple of them here purely to verify that the
        install succeeded; nothing is instantiated.

        Args:
            model_path: local directory, or a Hugging Face repo id (anything
                containing ``"/"`` that is not an existing directory).

        Returns:
            ``{"model_path": resolved, "loaded": False}``. ``resolved`` is the
            local snapshot path when the download succeeded; otherwise it is
            the original ``model_path`` string.

        Raises:
            RuntimeError: if the MuseTalk package cannot be imported.
        """
        # Check the package first so a broken install fails fast instead of
        # after a potentially large weight download.
        try:
            from musetalk.utils.utils import load_all_model  # type: ignore[import-not-found]  # noqa: F401
            from musetalk.utils.audio_processor import AudioProcessor  # type: ignore[import-not-found]  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "MuseTalk Python package is not importable. "
                "Check that third_party/MuseTalk was installed via "
                "`pip install /opt/MuseTalk` in the Dockerfile."
            ) from e

        resolved = model_path
        if not os.path.isdir(model_path) and "/" in model_path:
            try:
                from huggingface_hub import snapshot_download

                resolved = snapshot_download(repo_id=model_path)
            except Exception as e:  # pragma: no cover
                # Deliberately best-effort: keep the unresolved string so
                # voice-only startup is not blocked by a download failure.
                log.warning("Could not snapshot_download MuseTalk repo: %s", e)

        # lip_sync is expected to load models lazily on first call, so
        # import-time failures don't block voice-only startup.
        return {"model_path": resolved, "loaded": False}

    # --- Inference ---------------------------------------------------------
    def lip_sync(
        self,
        frames: np.ndarray,
        audio: np.ndarray,
        sample_rate: int,
        fps: int,
    ) -> np.ndarray:
        """Return new frames with lip-sync applied to match ``audio``.

        Args:
            frames: uint8 ``[T, H, W, 3]`` RGB base frames (``T >= 1``).
            audio: float32 mono 1D audio.
            sample_rate: sample rate of ``audio`` in Hz (must be > 0).
            fps: frame rate of ``frames`` (must be > 0).

        Returns:
            uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
            to match the audio duration at ``fps``.

        Raises:
            ValueError: on malformed ``frames`` / ``audio`` or a non-positive
                ``sample_rate`` / ``fps``.
            NotImplementedError: always, for valid input — the MuseTalk
                pipeline is not wired up yet.
        """
        frames = np.asarray(frames)
        audio = np.asarray(audio)
        if frames.ndim != 4 or frames.shape[-1] != 3:
            raise ValueError(
                f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
            )
        if frames.shape[0] == 0:
            # Nothing to animate, and there is nothing to loop when padding.
            raise ValueError("frames must contain at least one frame")
        if audio.ndim != 1:
            # A [N, C] array would make len(audio) the wrong sample count.
            raise ValueError(f"audio must be mono 1D, got shape {audio.shape}")
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")

        # Normalise frame count to audio duration so the caller doesn't have
        # to do the arithmetic.
        target_t = int(round(audio.shape[0] / sample_rate * fps))
        if target_t > 0 and frames.shape[0] != target_t:
            frames = _fit_frames_to_length(frames, target_t)

        # MuseTalk's real inference path (see third_party/MuseTalk/scripts/
        # realtime_inference.py::Avatar.inference) needs:
        #   - mmpose + mmcv + mmengine (dwpose keypoint detection)
        #   - face_alignment (bbox)
        #   - MuseTalk UNet + VAE weights (TMElyralab/MuseTalk HF repo)
        #   - Whisper encoder (openai/whisper-tiny)
        #   - face_parsing weights
        # Plus its preprocessing module has import-time side effects that
        # load dwpose weights from a CWD-relative path. Turn the full
        # pipeline on by extending this method once those deps are
        # installed and weights are resolved — until then, callers should
        # keep ``config.video.musetalk.enabled: false`` and VideoEngine
        # will skip the lip-sync pass.
        raise NotImplementedError(
            "MuseTalk lip-sync pipeline is not wired up yet. "
            "Set config.video.musetalk.enabled=false to bypass."
        )
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
"""Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.
Repeats with a ping-pong / boomerang tail so the seam between loops is
less jarring than a hard cut back to frame 0.
"""
if target_t <= 0:
return frames
t = len(frames)
if t == target_t:
return frames
if t > target_t:
return frames[:target_t]
# Extend via ping-pong looping
extended = [frames]
total = t
flip = True
while total < target_t:
seg = frames[::-1] if flip else frames
extended.append(seg)
total += t
flip = not flip
return np.concatenate(extended, axis=0)[:target_t]
def _ensure_uint8_rgb(arr) -> np.ndarray:
"""Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB."""
result = np.asarray(arr)
if result.dtype != np.uint8:
if result.dtype in (np.float32, np.float64):
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
else:
result = result.astype(np.uint8)
if result.ndim == 3:
result = result[None, ...]
return result