152 lines
5.8 KiB
Python
152 lines
5.8 KiB
Python
"""MuseTalk audio-driven lip-sync wrapper.
|
|
|
|
MuseTalk takes a sequence of face frames + driving audio and returns a new
|
|
sequence of frames where the mouth region is animated to match the audio.
|
|
|
|
This module isolates MuseTalk's real API behind a single ``lip_sync()``
|
|
method. MuseTalk's upstream Python surface varies between forks — if the
|
|
real import path or call signature differs, update this file only.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class MuseTalkEngine:
    """Thin wrapper over MuseTalk inference.

    Construction verifies that the ``musetalk`` package is importable and
    resolves where the weights live. ``lip_sync`` is the single public
    inference entry point; the real pipeline is not wired up yet, so it
    currently validates its inputs and then raises ``NotImplementedError``.
    """

    def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
        """Resolve weights and check that the MuseTalk package is installed.

        Args:
            model_path: local directory containing the weights, or a
                ``owner/name`` style id passed to ``snapshot_download``.

        Raises:
            RuntimeError: if the ``musetalk`` modules cannot be imported.
        """
        self.model_path = model_path

        # There is no single upstream inference class to import; _load_impl
        # only checks that the low-level musetalk modules are importable and
        # returns a small state dict (resolved weight path + "loaded" flag).
        self._infer = self._load_impl(model_path)
        log.info("MuseTalk engine loaded from %s", model_path)

    @staticmethod
    def _load_impl(model_path: str) -> dict:
        """Load the MuseTalk inference implementation.

        Upstream MuseTalk has no library-style entry point — it's a bundle
        of training/inference CLI scripts. The bhetherman/MuseTalk fork at
        ``third_party/MuseTalk`` adds package metadata but the low-level
        API is still the raw ``musetalk.utils.*`` and ``musetalk.models.*``
        modules. We import them here to verify the install succeeded; the
        actual pipeline (VAE, UNet, Whisper, face detection, blending)
        is wired up inside ``MuseTalkEngine.lip_sync``.

        Args:
            model_path: local weights directory or a hub repo id.

        Returns:
            ``{"model_path": <resolved path>, "loaded": False}``. The path
            falls back to ``model_path`` unchanged when it is already a
            directory, contains no ``/``, or the download fails.

        Raises:
            RuntimeError: if the ``musetalk`` modules cannot be imported.
        """
        resolved = model_path
        # Anything that is not an existing directory but contains a slash is
        # treated as a hub repo id.
        if not os.path.isdir(model_path) and "/" in model_path:
            try:
                from huggingface_hub import snapshot_download

                resolved = snapshot_download(repo_id=model_path)
            except Exception as e:  # pragma: no cover
                # Best effort: a failed download is logged, not fatal, and
                # the unresolved ``model_path`` is kept.
                log.warning("Could not snapshot_download MuseTalk repo: %s", e)

        try:
            from musetalk.utils.utils import load_all_model  # type: ignore[import-not-found]  # noqa: F401
            from musetalk.utils.audio_processor import AudioProcessor  # type: ignore[import-not-found]  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "MuseTalk Python package is not importable. "
                "Check that third_party/MuseTalk was installed via "
                "`pip install /opt/MuseTalk` in the Dockerfile."
            ) from e

        # Return the resolved weight path; lip_sync loads models lazily on
        # first call so import-time failures don't block voice-only startup.
        return {"model_path": resolved, "loaded": False}

    # --- Inference ---------------------------------------------------------

    def lip_sync(
        self,
        frames: np.ndarray,
        audio: np.ndarray,
        sample_rate: int,
        fps: int,
    ) -> np.ndarray:
        """Return new frames with lip-sync applied to match ``audio``.

        Args:
            frames: uint8 ``[T, H, W, 3]`` RGB base frames.
            audio: float32 mono 1D audio.
            sample_rate: sample rate of ``audio``; must be positive.
            fps: frame rate of ``frames``; must be positive.

        Returns:
            uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
            to match audio duration at ``fps``.

        Raises:
            ValueError: if ``frames`` is not 4-D with a trailing axis of 3,
                if ``sample_rate`` or ``fps`` is not positive, or if
                ``frames`` is empty while the audio requires at least one
                frame.
            NotImplementedError: always, once validation passes — the
                inference pipeline is not wired up yet.
        """
        if frames.ndim != 4 or frames.shape[-1] != 3:
            raise ValueError(
                f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
            )
        # Reject nonsensical rates up front instead of surfacing a bare
        # ZeroDivisionError (or a silently skipped fit) from the arithmetic.
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")

        # Normalise frame count to audio duration so the caller doesn't have
        # to do the arithmetic.
        target_t = int(round(len(audio) / sample_rate * fps))
        if target_t > 0 and len(frames) != target_t:
            if len(frames) == 0:
                # Nothing to repeat: an empty clip cannot be looped up to
                # the requested length.
                raise ValueError(
                    f"frames is empty but audio requires {target_t} frame(s)"
                )
            frames = _fit_frames_to_length(frames, target_t)

        # MuseTalk's real inference path (see third_party/MuseTalk/scripts/
        # realtime_inference.py::Avatar.inference) needs:
        #   - mmpose + mmcv + mmengine (dwpose keypoint detection)
        #   - face_alignment (bbox)
        #   - MuseTalk UNet + VAE weights (TMElyralab/MuseTalk HF repo)
        #   - Whisper encoder (openai/whisper-tiny)
        #   - face_parsing weights
        # Plus its preprocessing module has import-time side effects that
        # load dwpose weights from a CWD-relative path. Turn the full
        # pipeline on by extending this method once those deps are
        # installed and weights are resolved — until then, callers should
        # keep ``config.video.musetalk.enabled: false`` and VideoEngine
        # will skip the lip-sync pass.
        raise NotImplementedError(
            "MuseTalk lip-sync pipeline is not wired up yet. "
            "Set config.video.musetalk.enabled=false to bypass."
        )
|
|
|
|
|
|
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
    """Trim or repeat ``frames`` along axis 0 to exactly ``target_t`` entries.

    Longer inputs are truncated. Shorter inputs are extended by ping-pong /
    boomerang looping (forward, reversed, forward, ...) so the seam between
    loops is less jarring than a hard cut back to frame 0. The turn-around
    frame is repeated at each seam, e.g. three frames extended to eight give
    the order ``0 1 2 2 1 0 0 1``.

    Args:
        frames: array whose first axis is time.
        target_t: desired number of frames. Values ``<= 0`` return
            ``frames`` unchanged.

    Returns:
        ``frames`` itself when no change is needed, a leading slice when
        trimming, or a new array of length ``target_t`` when extending.

    Raises:
        ValueError: if ``frames`` is empty and ``target_t`` is positive;
            there is nothing to repeat.
    """
    if target_t <= 0:
        return frames
    t = len(frames)
    if t == target_t:
        return frames
    if t > target_t:
        return frames[:target_t]
    if t == 0:
        # A loop that appends zero-length segments until the total reaches
        # target_t would never terminate; fail loudly instead.
        raise ValueError(
            f"cannot extend an empty frame sequence to {target_t} frame(s)"
        )

    # One ping-pong cycle is 2*t frames long: indices 0..t-1 forward, then
    # t-1..0 backward. Map every output position onto that cycle and gather
    # once, rather than concatenating whole copies and slicing the excess.
    period = 2 * t
    phase = np.arange(target_t) % period
    source_idx = np.where(phase < t, phase, period - 1 - phase)
    return np.asarray(frames)[source_idx]
|
|
|
|
|
|
def _ensure_uint8_rgb(arr) -> np.ndarray:
    """Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB.

    Conversion rules by input dtype:

    * ``uint8`` is passed through untouched.
    * Any floating dtype (including ``float16``) is scaled by 255, clipped
      to ``[0, 255]`` and truncated to ``uint8`` — i.e. float input is
      treated as being in the ``[0, 1]`` range.
    * Other integer dtypes are clipped to ``[0, 255]`` before the cast so
      out-of-range values saturate instead of wrapping modulo 256.
    * Anything else (e.g. bool) is cast directly.

    A 3-D result is treated as a single frame and gains a leading axis of
    length 1.

    Args:
        arr: array-like accepted by ``np.asarray``.

    Returns:
        A ``uint8`` array; 4-D when the input was 3-D or 4-D.
    """
    result = np.asarray(arr)
    if result.dtype != np.uint8:
        if np.issubdtype(result.dtype, np.floating):
            # Half precision cannot represent every value in 0..255 exactly,
            # so widen it before scaling. float32/float64 keep their own
            # precision so existing results are unchanged.
            if result.dtype.itemsize < 4:
                result = result.astype(np.float32)
            result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
        elif np.issubdtype(result.dtype, np.integer):
            result = np.clip(result, 0, 255).astype(np.uint8)
        else:
            result = result.astype(np.uint8)
    if result.ndim == 3:
        result = result[None, ...]
    return result
|