165 lines
5.7 KiB
Python
165 lines
5.7 KiB
Python
"""MuseTalk audio-driven lip-sync wrapper.

MuseTalk takes a sequence of face frames + driving audio and returns a new
sequence of frames where the mouth region is animated to match the audio.

This module isolates MuseTalk's real API behind a single ``lip_sync()``
method. MuseTalk's upstream Python surface varies between forks — if the
real import path or call signature differs, update this file only.
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class MuseTalkEngine:
    """Thin wrapper over MuseTalk inference.

    The concrete MuseTalk object is located once at construction time by
    ``_load_impl`` and stored on ``self._infer``; ``lip_sync`` then probes
    that object for a usable inference method on every call.
    """

    def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
        """Load the MuseTalk implementation for ``model_path``.

        Args:
            model_path: a local directory, or a string containing ``/`` that
                is treated as a Hugging Face repo id and downloaded.

        Raises:
            RuntimeError: if no known MuseTalk entry point can be imported.
        """
        self.model_path = model_path

        # MuseTalk's canonical entry point is ``musetalk.inference`` or a
        # similar ``MuseTalkInfer`` class. Try the most common imports.
        self._infer = self._load_impl(model_path)
        log.info("MuseTalk engine loaded from %s", model_path)

    @staticmethod
    def _load_impl(model_path: str):
        """Load the MuseTalk inference implementation.

        If ``model_path`` is not an existing directory and contains ``/`` it
        is passed to ``huggingface_hub.snapshot_download``; a failed download
        is only logged, and the original string is handed to MuseTalk as-is.

        Three known import layouts are then tried in order. If none of them
        work the error message points at this file so you know where to fix
        it, and the last ``ImportError`` is attached as ``__cause__`` so the
        traceback shows what actually went wrong.

        Args:
            model_path: local directory or Hugging Face repo id.

        Returns:
            The instantiated MuseTalk inference object.

        Raises:
            RuntimeError: if every known entry point fails to import.
        """
        resolved = model_path
        if not os.path.isdir(model_path) and "/" in model_path:
            try:
                from huggingface_hub import snapshot_download

                resolved = snapshot_download(repo_id=model_path)
            except Exception as e:  # pragma: no cover
                # Best effort: fall through with the unresolved path.
                log.warning("Could not snapshot_download MuseTalk repo: %s", e)

        # Remember the most recent import failure so it can be chained onto
        # the final RuntimeError instead of being silently dropped.
        last_error = None

        # Try upstream MuseTalk repo layout.
        try:
            from musetalk.musetalk_inference import MuseTalkInference  # type: ignore[import-not-found]

            return MuseTalkInference(model_path=resolved)
        except ImportError as exc:
            last_error = exc

        try:
            from musetalk.inference import MuseTalkInfer  # type: ignore[import-not-found]

            return MuseTalkInfer(model_path=resolved)
        except ImportError as exc:
            last_error = exc

        try:
            from musetalk import Inference  # type: ignore[import-not-found]

            return Inference(model_path=resolved)
        except ImportError as exc:
            last_error = exc

        raise RuntimeError(
            "MuseTalk is installed but no known Python entry point was found. "
            "Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
            "to match the installed MuseTalk version."
        ) from last_error

    # --- Inference ---------------------------------------------------------

    def lip_sync(
        self,
        frames: np.ndarray,
        audio: np.ndarray,
        sample_rate: int,
        fps: int,
    ) -> np.ndarray:
        """Return new frames with lip-sync applied to match ``audio``.

        Args:
            frames: uint8 ``[T, H, W, 3]`` RGB base frames.
            audio: float32 mono 1D audio.
            sample_rate: sample rate of ``audio``; must be positive.
            fps: frame rate of ``frames``.

        Returns:
            uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
            to match audio duration at ``fps``.

        Raises:
            ValueError: if ``frames`` is not 4-D with a trailing axis of 3,
                if ``sample_rate`` is not positive, or if ``frames`` is empty
                while the audio requires at least one frame.
            RuntimeError: if the loaded MuseTalk object exposes none of the
                probed methods with a compatible signature. The last
                ``TypeError`` seen is attached as ``__cause__``.
        """
        if frames.ndim != 4 or frames.shape[-1] != 3:
            raise ValueError(
                f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
            )
        # Guard the division below; a zero rate would otherwise surface as a
        # bare ZeroDivisionError with no hint about which argument is wrong.
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")

        # Normalise frame count to audio duration so the caller doesn't have
        # to do the arithmetic.
        target_t = int(round(len(audio) / sample_rate * fps))
        if target_t > 0 and len(frames) != target_t:
            if len(frames) == 0:
                # There is nothing to loop; fail here rather than asking the
                # frame-fitting helper to extend an empty clip.
                raise ValueError(
                    f"frames is empty but audio requires {target_t} frames"
                )
            frames = _fit_frames_to_length(frames, target_t)

        # The real MuseTalk call signature varies. Most common is a method
        # like ``run(frames, audio, sr, fps)`` or ``infer(...)``.
        last_error = None
        for method_name in ("run", "infer", "lip_sync", "__call__"):
            method = getattr(self._infer, method_name, None)
            if method is None:
                continue
            # Only the calls themselves sit inside ``try`` so that a failure
            # while converting the output is not mistaken for a signature
            # mismatch and does not trigger a second inference run.
            try:
                result = method(
                    frames=frames,
                    audio=audio,
                    sample_rate=sample_rate,
                    fps=fps,
                )
            except TypeError:
                # Try positional
                try:
                    result = method(frames, audio, sample_rate, fps)
                except TypeError as exc:
                    last_error = exc
                    continue
            return _ensure_uint8_rgb(result)

        raise RuntimeError(
            "MuseTalk wrapper could not find a working inference method. "
            "Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
        ) from last_error
|
|
|
|
|
|
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
|
|
"""Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.
|
|
|
|
Repeats with a ping-pong / boomerang tail so the seam between loops is
|
|
less jarring than a hard cut back to frame 0.
|
|
"""
|
|
if target_t <= 0:
|
|
return frames
|
|
t = len(frames)
|
|
if t == target_t:
|
|
return frames
|
|
if t > target_t:
|
|
return frames[:target_t]
|
|
# Extend via ping-pong looping
|
|
extended = [frames]
|
|
total = t
|
|
flip = True
|
|
while total < target_t:
|
|
seg = frames[::-1] if flip else frames
|
|
extended.append(seg)
|
|
total += t
|
|
flip = not flip
|
|
return np.concatenate(extended, axis=0)[:target_t]
|
|
|
|
|
|
def _ensure_uint8_rgb(arr) -> np.ndarray:
|
|
"""Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB."""
|
|
result = np.asarray(arr)
|
|
if result.dtype != np.uint8:
|
|
if result.dtype in (np.float32, np.float64):
|
|
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
|
|
else:
|
|
result = result.astype(np.uint8)
|
|
if result.ndim == 3:
|
|
result = result[None, ...]
|
|
return result
|