"""MuseTalk audio-driven lip-sync wrapper. MuseTalk takes a sequence of face frames + driving audio and returns a new sequence of frames where the mouth region is animated to match the audio. This module isolates MuseTalk's real API behind a single ``lip_sync()`` method. MuseTalk's upstream Python surface varies between forks — if the real import path or call signature differs, update this file only. """ from __future__ import annotations import logging import os import numpy as np log = logging.getLogger(__name__) class MuseTalkEngine: """Thin wrapper over MuseTalk inference.""" def __init__(self, model_path: str = "TMElyralab/MuseTalk"): self.model_path = model_path # MuseTalk's canonical entry point is ``musetalk.inference`` or a # similar ``MuseTalkInfer`` class. Try the most common imports. self._infer = self._load_impl(model_path) log.info("MuseTalk engine loaded from %s", model_path) @staticmethod def _load_impl(model_path: str): """Load the MuseTalk inference implementation. If none of the known entry points work the error message points at this file so you know where to fix it. """ resolved = model_path if not os.path.isdir(model_path) and "/" in model_path: try: from huggingface_hub import snapshot_download resolved = snapshot_download(repo_id=model_path) except Exception as e: # pragma: no cover log.warning("Could not snapshot_download MuseTalk repo: %s", e) # Try upstream MuseTalk repo layout. try: from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found] return MuseTalkInference(model_path=resolved) except ImportError: pass try: from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found] return MuseTalkInfer(model_path=resolved) except ImportError: pass try: from musetalk import Inference # type: ignore[import-not-found] return Inference(model_path=resolved) except ImportError: pass raise RuntimeError( "MuseTalk is installed but no known Python entry point was found. " "Update server/video_models/musetalk.py::MuseTalkEngine._load_impl " "to match the installed MuseTalk version." ) # --- Inference --------------------------------------------------------- def lip_sync( self, frames: np.ndarray, audio: np.ndarray, sample_rate: int, fps: int, ) -> np.ndarray: """Return new frames with lip-sync applied to match ``audio``. Args: frames: uint8 ``[T, H, W, 3]`` RGB base frames. audio: float32 mono 1D audio. sample_rate: sample rate of ``audio``. fps: frame rate of ``frames``. Returns: uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded to match audio duration at ``fps``. """ if frames.ndim != 4 or frames.shape[-1] != 3: raise ValueError( f"frames must be [T, H, W, 3] uint8, got {frames.shape}" ) # Normalise frame count to audio duration so the caller doesn't have # to do the arithmetic. target_t = int(round(len(audio) / sample_rate * fps)) if target_t > 0 and len(frames) != target_t: frames = _fit_frames_to_length(frames, target_t) # The real MuseTalk call signature varies. Most common is a method # like ``run(frames, audio, sr, fps)`` or ``infer(...)``. for method_name in ("run", "infer", "lip_sync", "__call__"): method = getattr(self._infer, method_name, None) if method is None: continue try: result = method( frames=frames, audio=audio, sample_rate=sample_rate, fps=fps, ) return _ensure_uint8_rgb(result) except TypeError: # Try positional try: result = method(frames, audio, sample_rate, fps) return _ensure_uint8_rgb(result) except TypeError: continue raise RuntimeError( "MuseTalk wrapper could not find a working inference method. " "Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync." ) def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray: """Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``. Repeats with a ping-pong / boomerang tail so the seam between loops is less jarring than a hard cut back to frame 0. """ if target_t <= 0: return frames t = len(frames) if t == target_t: return frames if t > target_t: return frames[:target_t] # Extend via ping-pong looping extended = [frames] total = t flip = True while total < target_t: seg = frames[::-1] if flip else frames extended.append(seg) total += t flip = not flip return np.concatenate(extended, axis=0)[:target_t] def _ensure_uint8_rgb(arr) -> np.ndarray: """Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB.""" result = np.asarray(arr) if result.dtype != np.uint8: if result.dtype in (np.float32, np.float64): result = np.clip(result * 255.0, 0, 255).astype(np.uint8) else: result = result.astype(np.uint8) if result.ndim == 3: result = result[None, ...] return result