"""MuseTalk audio-driven lip-sync wrapper.

MuseTalk takes a sequence of face frames + driving audio and returns a new
sequence of frames where the mouth region is animated to match the audio.

This module isolates MuseTalk's real API behind a single ``lip_sync()``
method. MuseTalk's upstream Python surface varies between forks — if the real
import path or call signature differs, update this file only.
"""

from __future__ import annotations

import logging
import os

import numpy as np

log = logging.getLogger(__name__)


class MuseTalkEngine:
    """Thin wrapper over MuseTalk inference."""

    def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
        """Verify the MuseTalk install and remember where the weights live.

        Args:
            model_path: local directory containing the weights, or a Hugging
                Face repo id (``owner/name``) to download a snapshot of.

        Raises:
            RuntimeError: if the ``musetalk`` package cannot be imported.
        """
        self.model_path = model_path
        # ``_load_impl`` only checks that the package imports and resolves
        # the weight location; it returns a small state dict, not a model.
        self._infer = self._load_impl(model_path)
        log.info("MuseTalk engine loaded from %s", model_path)

    @staticmethod
    def _load_impl(model_path: str) -> dict:
        """Load the MuseTalk inference implementation.

        Upstream MuseTalk has no library-style entry point — it's a bundle of
        training/inference CLI scripts. The bhetherman/MuseTalk fork at
        ``third_party/MuseTalk`` adds package metadata but the low-level API
        is still the raw ``musetalk.utils.*`` and ``musetalk.models.*``
        modules. We import them here to verify the install succeeded; the
        actual pipeline (VAE, UNet, Whisper, face detection, blending) is
        wired up inside ``MuseTalkEngine.lip_sync``.

        Args:
            model_path: local weights directory or Hugging Face repo id.

        Returns:
            ``{"model_path": <resolved path>, "loaded": False}``. The path is
            the snapshot directory when the download succeeded, otherwise
            ``model_path`` unchanged.

        Raises:
            RuntimeError: if the ``musetalk`` package is not importable.
        """
        resolved = model_path
        # Anything that is not an existing directory but contains a slash is
        # treated as a Hugging Face repo id.
        if not os.path.isdir(model_path) and "/" in model_path:
            try:
                from huggingface_hub import snapshot_download

                resolved = snapshot_download(repo_id=model_path)
            except Exception as e:  # pragma: no cover
                # Deliberately best-effort: a failed download is logged and
                # the unresolved ``model_path`` is carried forward.
                log.warning("Could not snapshot_download MuseTalk repo: %s", e)

        try:
            from musetalk.utils.utils import load_all_model  # type: ignore[import-not-found]  # noqa: F401
            from musetalk.utils.audio_processor import AudioProcessor  # type: ignore[import-not-found]  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "MuseTalk Python package is not importable. "
                "Check that third_party/MuseTalk was installed via "
                "`pip install /opt/MuseTalk` in the Dockerfile."
            ) from e

        # Return the resolved weight path; lip_sync loads models lazily on
        # first call so import-time failures don't block voice-only startup.
        return {"model_path": resolved, "loaded": False}

    # --- Inference ---------------------------------------------------------

    def lip_sync(
        self,
        frames: np.ndarray,
        audio: np.ndarray,
        sample_rate: int,
        fps: int,
    ) -> np.ndarray:
        """Return new frames with lip-sync applied to match ``audio``.

        Args:
            frames: uint8 ``[T, H, W, 3]`` RGB base frames.
            audio: float32 mono 1D audio.
            sample_rate: sample rate of ``audio``; must be positive.
            fps: frame rate of ``frames``; must be positive.

        Returns:
            uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
            to match audio duration at ``fps``.

        Raises:
            ValueError: if ``frames`` is not ``[T, H, W, 3]``, if
                ``sample_rate`` or ``fps`` is not positive, or if ``frames``
                is empty while the audio has a non-zero duration.
            NotImplementedError: always, for otherwise valid input — the
                inference pipeline is not wired up yet.
        """
        if frames.ndim != 4 or frames.shape[-1] != 3:
            raise ValueError(
                f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
            )
        # Reject these up front: a zero sample rate would otherwise surface
        # as a ZeroDivisionError from the duration arithmetic below.
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")

        # Normalise frame count to audio duration so the caller doesn't have
        # to do the arithmetic.
        target_t = int(round(len(audio) / sample_rate * fps))
        if target_t > 0 and len(frames) != target_t:
            frames = _fit_frames_to_length(frames, target_t)

        # MuseTalk's real inference path (see third_party/MuseTalk/scripts/
        # realtime_inference.py::Avatar.inference) needs:
        #   - mmpose + mmcv + mmengine (dwpose keypoint detection)
        #   - face_alignment (bbox)
        #   - MuseTalk UNet + VAE weights (TMElyralab/MuseTalk HF repo)
        #   - Whisper encoder (openai/whisper-tiny)
        #   - face_parsing weights
        # Plus its preprocessing module has import-time side effects that
        # load dwpose weights from a CWD-relative path. Turn the full
        # pipeline on by extending this method once those deps are
        # installed and weights are resolved — until then, callers should
        # keep ``config.video.musetalk.enabled: false`` and VideoEngine
        # will skip the lip-sync pass.
        raise NotImplementedError(
            "MuseTalk lip-sync pipeline is not wired up yet. "
            "Set config.video.musetalk.enabled=false to bypass."
        )


def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
    """Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.

    Repeats with a ping-pong / boomerang tail so the seam between loops is
    less jarring than a hard cut back to frame 0. The sequence played is
    forward, backward, forward, ... with the turnaround frame held for one
    extra step at each end (``0 1 2 2 1 0 0 1 ...`` for three frames).

    Args:
        frames: array whose first axis is time.
        target_t: desired number of frames. Non-positive values leave
            ``frames`` untouched.

    Returns:
        ``frames`` itself when no change is needed, a leading slice when
        trimming, or a new array of length ``target_t`` when extending.

    Raises:
        ValueError: if ``frames`` is empty and ``target_t`` is positive —
            there is nothing to repeat.
    """
    if target_t <= 0:
        return frames
    t = len(frames)
    if t == target_t:
        return frames
    if t > target_t:
        return frames[:target_t]
    if t == 0:
        # Looping an empty clip can never reach target_t; fail loudly
        # instead of spinning forever.
        raise ValueError(
            f"cannot extend an empty frame sequence to {target_t} frames"
        )

    # Position of every output frame inside one forward+backward cycle of
    # length 2*t; the second half of the cycle is mirrored back onto the
    # source indices. A single gather avoids materialising whole loop copies.
    phase = np.arange(target_t) % (2 * t)
    source_idx = np.where(phase < t, phase, 2 * t - 1 - phase)
    return np.asarray(frames)[source_idx]


def _ensure_uint8_rgb(arr) -> np.ndarray:
    """Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB.

    Conversion rules by input dtype:
      - uint8: passed through unchanged.
      - any floating dtype: multiplied by 255 (i.e. treated as ``[0, 1]``),
        clipped to ``[0, 255]`` and truncated.
      - other integer dtypes: saturated to ``[0, 255]`` rather than wrapped.
      - anything else (e.g. bool): plain ``astype(np.uint8)``.

    A 3-D result is given a leading axis of length 1; no other shape
    checking is done.
    """
    result = np.asarray(arr)
    if result.dtype != np.uint8:
        if np.issubdtype(result.dtype, np.floating):
            result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
        elif np.issubdtype(result.dtype, np.integer):
            # Keep the clip bounds inside the dtype's own range so narrow
            # types such as int8 never see an out-of-range bound.
            info = np.iinfo(result.dtype)
            lo = max(int(info.min), 0)
            hi = min(int(info.max), 255)
            result = np.clip(result, lo, hi).astype(np.uint8)
        else:
            result = result.astype(np.uint8)
    if result.ndim == 3:
        result = result[None, ...]
    return result