first stab at adding video
This commit is contained in:
@@ -0,0 +1,164 @@
|
||||
"""MuseTalk audio-driven lip-sync wrapper.
|
||||
|
||||
MuseTalk takes a sequence of face frames + driving audio and returns a new
|
||||
sequence of frames where the mouth region is animated to match the audio.
|
||||
|
||||
This module isolates MuseTalk's real API behind a single ``lip_sync()``
|
||||
method. MuseTalk's upstream Python surface varies between forks — if the
|
||||
real import path or call signature differs, update this file only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MuseTalkEngine:
|
||||
"""Thin wrapper over MuseTalk inference."""
|
||||
|
||||
def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
|
||||
self.model_path = model_path
|
||||
|
||||
# MuseTalk's canonical entry point is ``musetalk.inference`` or a
|
||||
# similar ``MuseTalkInfer`` class. Try the most common imports.
|
||||
self._infer = self._load_impl(model_path)
|
||||
log.info("MuseTalk engine loaded from %s", model_path)
|
||||
|
||||
@staticmethod
|
||||
def _load_impl(model_path: str):
|
||||
"""Load the MuseTalk inference implementation.
|
||||
|
||||
If none of the known entry points work the error message points at
|
||||
this file so you know where to fix it.
|
||||
"""
|
||||
resolved = model_path
|
||||
if not os.path.isdir(model_path) and "/" in model_path:
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
resolved = snapshot_download(repo_id=model_path)
|
||||
except Exception as e: # pragma: no cover
|
||||
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
|
||||
|
||||
# Try upstream MuseTalk repo layout.
|
||||
try:
|
||||
from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found]
|
||||
return MuseTalkInference(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found]
|
||||
return MuseTalkInfer(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from musetalk import Inference # type: ignore[import-not-found]
|
||||
return Inference(model_path=resolved)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
raise RuntimeError(
|
||||
"MuseTalk is installed but no known Python entry point was found. "
|
||||
"Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
|
||||
"to match the installed MuseTalk version."
|
||||
)
|
||||
|
||||
# --- Inference ---------------------------------------------------------
|
||||
|
||||
def lip_sync(
|
||||
self,
|
||||
frames: np.ndarray,
|
||||
audio: np.ndarray,
|
||||
sample_rate: int,
|
||||
fps: int,
|
||||
) -> np.ndarray:
|
||||
"""Return new frames with lip-sync applied to match ``audio``.
|
||||
|
||||
Args:
|
||||
frames: uint8 ``[T, H, W, 3]`` RGB base frames.
|
||||
audio: float32 mono 1D audio.
|
||||
sample_rate: sample rate of ``audio``.
|
||||
fps: frame rate of ``frames``.
|
||||
|
||||
Returns:
|
||||
uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
|
||||
to match audio duration at ``fps``.
|
||||
"""
|
||||
if frames.ndim != 4 or frames.shape[-1] != 3:
|
||||
raise ValueError(
|
||||
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
|
||||
)
|
||||
|
||||
# Normalise frame count to audio duration so the caller doesn't have
|
||||
# to do the arithmetic.
|
||||
target_t = int(round(len(audio) / sample_rate * fps))
|
||||
if target_t > 0 and len(frames) != target_t:
|
||||
frames = _fit_frames_to_length(frames, target_t)
|
||||
|
||||
# The real MuseTalk call signature varies. Most common is a method
|
||||
# like ``run(frames, audio, sr, fps)`` or ``infer(...)``.
|
||||
for method_name in ("run", "infer", "lip_sync", "__call__"):
|
||||
method = getattr(self._infer, method_name, None)
|
||||
if method is None:
|
||||
continue
|
||||
try:
|
||||
result = method(
|
||||
frames=frames,
|
||||
audio=audio,
|
||||
sample_rate=sample_rate,
|
||||
fps=fps,
|
||||
)
|
||||
return _ensure_uint8_rgb(result)
|
||||
except TypeError:
|
||||
# Try positional
|
||||
try:
|
||||
result = method(frames, audio, sample_rate, fps)
|
||||
return _ensure_uint8_rgb(result)
|
||||
except TypeError:
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
"MuseTalk wrapper could not find a working inference method. "
|
||||
"Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
|
||||
)
|
||||
|
||||
|
||||
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
|
||||
"""Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.
|
||||
|
||||
Repeats with a ping-pong / boomerang tail so the seam between loops is
|
||||
less jarring than a hard cut back to frame 0.
|
||||
"""
|
||||
if target_t <= 0:
|
||||
return frames
|
||||
t = len(frames)
|
||||
if t == target_t:
|
||||
return frames
|
||||
if t > target_t:
|
||||
return frames[:target_t]
|
||||
# Extend via ping-pong looping
|
||||
extended = [frames]
|
||||
total = t
|
||||
flip = True
|
||||
while total < target_t:
|
||||
seg = frames[::-1] if flip else frames
|
||||
extended.append(seg)
|
||||
total += t
|
||||
flip = not flip
|
||||
return np.concatenate(extended, axis=0)[:target_t]
|
||||
|
||||
|
||||
def _ensure_uint8_rgb(arr) -> np.ndarray:
|
||||
"""Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB."""
|
||||
result = np.asarray(arr)
|
||||
if result.dtype != np.uint8:
|
||||
if result.dtype in (np.float32, np.float64):
|
||||
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
|
||||
else:
|
||||
result = result.astype(np.uint8)
|
||||
if result.ndim == 3:
|
||||
result = result[None, ...]
|
||||
return result
|
||||
Reference in New Issue
Block a user