66 lines
2.0 KiB
Python
66 lines
2.0 KiB
Python
"""Unit tests for the frame-length fitting and uint8-RGB coercion helpers in
server.video_models.musetalk (``_fit_frames_to_length`` and ``_ensure_uint8_rgb``).
|
|
|
|
Pure Python plus NumPy: does not import MuseTalk itself.
|
|
"""
|
|
import numpy as np
|
|
|
|
from server.video_models.musetalk import _fit_frames_to_length, _ensure_uint8_rgb
|
|
|
|
|
|
def _make_frames(t, h=2, w=2):
    """Build a deterministic ``(t, h, w, 3)`` uint8 clip whose values count up from 0.

    Every channel entry gets a distinct value only while ``t * h * w * 3`` stays
    within the 256 values uint8 can hold, so keep fixtures small.
    """
    total = t * h * w * 3
    ramp = np.arange(total, dtype=np.uint8)
    return ramp.reshape(t, h, w, 3)
|
|
|
|
|
|
def test_fit_frames_trim():
    """A target shorter than the clip keeps only the leading frames, in order."""
    clip = _make_frames(10)
    target = 4
    trimmed = _fit_frames_to_length(clip, target)
    assert trimmed.shape == (target, 2, 2, 3)
    np.testing.assert_array_equal(trimmed, clip[:target])
|
|
|
|
|
|
def test_fit_frames_passthrough_when_equal():
    """When the clip already has the target length it comes back unchanged.

    Returning the very same object or an equal copy are both acceptable.
    Identity implies element-wise (and shape) equality, so a single
    ``assert_array_equal`` covers both cases. Unlike a bare ``assert ... or ...``,
    it also reports the mismatching elements or shapes when it fails, which
    matches the assertion style used by the other tests in this module.
    """
    frames = _make_frames(5)
    out = _fit_frames_to_length(frames, 5)
    np.testing.assert_array_equal(out, frames)
|
|
|
|
|
|
def test_fit_frames_extends_with_pingpong():
    """Extending past the clip length bounces back and forth (ping-pong).

    With 3 source frames and a target of 8, the expected frame indices are
    0, 1, 2 (forward), then 2, 1, 0 (reversed), then 0, 1 (forward again, cut
    off at the target length).
    """
    src = _make_frames(3)
    extended = _fit_frames_to_length(src, 8)
    assert extended.shape == (8, 2, 2, 3)

    forward, backward, forward_again = extended[:3], extended[3:6], extended[6:8]
    np.testing.assert_array_equal(forward, src)
    np.testing.assert_array_equal(backward, src[::-1])
    np.testing.assert_array_equal(forward_again, src[:2])
|
|
|
|
|
|
def test_fit_frames_zero_target_returns_original():
    """A target length of 0 leaves the clip's contents untouched."""
    original = _make_frames(3)
    np.testing.assert_array_equal(_fit_frames_to_length(original, 0), original)
|
|
|
|
|
|
def test_ensure_uint8_rgb_from_float():
    """Float frames are converted to a uint8 array of the same 4-D shape.

    NOTE: assuming a x255 scale, 0.5 maps to 127.5. Expecting 127 (not 128)
    therefore pins truncation rather than rounding.
    """
    half_grey = np.full((5, 2, 2, 3), 0.5, dtype=np.float32)
    converted = _ensure_uint8_rgb(half_grey)
    assert converted.dtype == np.uint8
    assert converted.shape == (5, 2, 2, 3)
    assert converted[0, 0, 0, 0] == 127
|
|
|
|
|
|
def test_ensure_uint8_rgb_promotes_3d_to_4d():
    """A single (2, 2, 3) image gains a leading frame axis of length 1."""
    single_image = np.zeros((2, 2, 3), dtype=np.uint8)
    assert _ensure_uint8_rgb(single_image).shape == (1, 2, 2, 3)
|
|
|
|
|
|
def test_ensure_uint8_rgb_clips_float_out_of_range():
    """Out-of-range floats are clipped: 2.0 saturates to 255 and -1.0 to 0."""
    cases = (
        (2.0, 255),  # above range -> clipped to the uint8 maximum
        (-1.0, 0),   # below range -> clipped to the uint8 minimum
    )
    for fill_value, expected in cases:
        arr = np.full((1, 1, 1, 3), fill_value, dtype=np.float32)
        assert _ensure_uint8_rgb(arr)[0, 0, 0, 0] == expected
|